28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
53 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
67struct MemRefPointerLikeModel
68 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
71 return llvm::cast<MemRefType>(pointer).getElementType();
75struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
102 bool isRegionArgOfOp;
112 assert(isRegionArgOfOp &&
"Must describe a region operand");
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp &&
"Must describe a region operand");
121 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
125 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 bool isLoopOp()
const {
129 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
141 void skipIf(
bool v =
true) { skip = skip || v; }
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp =
true;
164 for (
Block *
b : traversal) {
166 if (&op == o && !found) {
170 if (op.getNumRegions()) {
173 isOnlyContainerOp =
false;
175 if (found && !isOnlyContainerOp)
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp =
false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
190 Component ®ionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp =
true;
192 regionArgOfOperation.isUnique =
true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
206 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
210 llvm_unreachable(
"Region not child of its parent operation");
212 regionArgOfOperation.isUnique =
false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
239 numSurroundingLoops = 0;
241 c.getLoopDepth() = numSurroundingLoops;
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
250 bool isLoopNest =
false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
258 }
else if (isLoopNest) {
260 c.skipIf(c.isUnique);
264 if (c.getLoopDepth() == 0)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
277 bool newRegion =
true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
286 if (!c.isRegionArgOfOp && c.getContainerOp())
292 llvm::raw_svector_ostream NameOS(Name);
293 for (
auto &c : llvm::reverse(components)) {
297 if (c.isRegionArgOfOp)
298 NameOS <<
"_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS <<
"_d" << c.getLoopDepth();
302 NameOS <<
"_s" << c.getOpPos();
305 return NameOS.str().str();
308void OpenMPDialect::initialize() {
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
336 mlir::LLVM::GlobalOp::attachInterface<
339 mlir::LLVM::LLVMFuncOp::attachInterface<
342 mlir::func::FuncOp::attachInterface<
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
387 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
389 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
390 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
398template <
typename ClauseAttr>
400 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
409 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
412template <
typename ClauseAttr>
414 p << stringifyEnum(attr.getValue());
439 std::optional<omp::LinearModifier> linearModifier;
441 linearModifier = omp::LinearModifier::val;
443 linearModifier = omp::LinearModifier::ref;
445 linearModifier = omp::LinearModifier::uval;
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.
parseLParen())
457 if (hasLinearModifierParens && parser.
parseRParen())
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
466 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
468 modifiers.push_back(UnitAttr::get(parser.
getContext()));
474 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
483 size_t linearVarsSize = linearVars.size();
484 for (
unsigned i = 0; i < linearVarsSize; ++i) {
488 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
491 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
493 p << linearVars[i] <<
" : " << linearTypes[i];
494 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
510 if (!linearModifiers)
512 if (linearModifiers->size() != linearVars.size())
514 <<
"expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (
Attribute attr : *linearModifiers) {
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
525 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
526 <<
"' may only be specified on a declare simd directive";
541 for (
const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->
emitOpError() <<
"nontemporal variable used more than once";
552 std::optional<ArrayAttr> alignments,
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
558 <<
"expected as many alignment values as aligned variables";
561 return op->
emitOpError() <<
"unexpected alignment values attribute";
567 for (
auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->
emitOpError() <<
"aligned variable used more than once";
575 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->
emitOpError() <<
"alignment should be greater than 0";
580 return op->
emitOpError() <<
"expected integer alignment";
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
614 std::optional<ArrayAttr> alignments) {
615 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
618 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
619 p <<
" -> " << (*alignments)[i];
630 if (modifiers.size() > 2)
632 for (
const auto &mod : modifiers) {
635 auto symbol = symbolizeScheduleModifier(mod);
638 <<
" unknown modifier type: " << mod;
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
648 }
else if (modifiers.size() == 2) {
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
654 <<
" incorrect modifier order";
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
691 chunkSize = std::nullopt;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
706 modifiers.push_back(mod);
712 if (!modifiers.empty()) {
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
718 return parser.
emitError(loc,
"invalid schedule modifier");
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd,
Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
738 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
740 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
772 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
779 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
781 p << stringifyClauseOrderKind(order.getValue());
784template <
typename ClauseTypeAttr,
typename ClauseType>
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
799 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
809 <<
"expected " << clauseName <<
" operand";
812 if (operand.has_value()) {
820template <
typename ClauseTypeAttr,
typename ClauseType>
823 ClauseTypeAttr prescriptiveness,
Value operand,
825 StringRef (*stringifyClauseType)(ClauseType)) {
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
831 p << operand <<
": " << operandType;
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType,
"grainsize");
849 ClauseGrainsizeTypeAttr grainsizeMod,
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
871 ClauseNumTasksTypeAttr numTasksMod,
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
920 if (!typeparams.empty()) {
921 p <<
'(' << typeparams <<
" : " << typeparamsTypes <<
')';
923 for (
auto sh :
shape) {
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->
emitOpError(
"dyn_groupprivate modifiers require a size operand");
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
949 bool parsedAccessGroup =
false;
950 bool parsedFallback =
false;
951 bool parsedSize =
false;
956 if (parsedAccessGroup)
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.
getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup =
true;
968 "duplicate fallback modifier");
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind ==
"abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind ==
"null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind ==
"default_mem")
983 fbEnum = FallbackModifier::default_mem;
986 "invalid fallback modifier '" + fbKind +
"'");
987 fallbackAttr = FallbackModifierAttr::get(parser.
getContext(), *fbEnum);
990 "expected ')' after fallback modifier");
991 parsedFallback =
true;
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1004 "expected ':' and type after size operand");
1008 "expected dyn_groupprivate_size operand");
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1018 bool needsComma =
false;
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1025 if (modifierSecond) {
1028 printer <<
"fallback(";
1029 printer << modifierSecond.getValue();
1034 if (dynGroupprivateSize) {
1037 printer << dynGroupprivateSize <<
" : " << sizeType;
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1057 UnitAttr &needsBarrier;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1075 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1093 return "private_barrier";
1103 ReductionModifierAttr *modifier =
nullptr,
1104 UnitAttr *needsBarrier =
nullptr) {
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1122 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword(
"byref").succeeded());
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1147 mapIndicesVec.push_back(-1);
1159 if (parser.parseType(types.emplace_back()))
1166 if (operands.size() != types.size())
1175 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1178 auto *argsBegin = regionPrivateArgs.begin();
1180 argsBegin + regionArgOffset + types.size());
1181 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1187 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1190 if (!mapIndicesVec.empty())
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1226 nullptr, &privateArgs->needsBarrier)))
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1249 AllRegionParseArgs args) {
1253 args.hasDeviceAddrArgs)))
1255 <<
"invalid `has_device_addr` format";
1258 args.hostEvalArgs)))
1260 <<
"invalid `host_eval` format";
1263 args.inReductionArgs)))
1265 <<
"invalid `in_reduction` format";
1270 <<
"invalid `map_entries` format";
1275 <<
"invalid `private` format";
1278 args.reductionArgs)))
1280 <<
"invalid `reduction` format";
1283 args.taskReductionArgs)))
1285 <<
"invalid `task_reduction` format";
1288 args.useDeviceAddrArgs)))
1290 <<
"invalid `use_device_addr` format";
1293 args.useDevicePtrArgs)))
1295 <<
"invalid `use_device_addr` format";
1297 return parser.
parseRegion(region, entryBlockArgs);
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1419struct MapPrintArgs {
1424struct PrivatePrintArgs {
1428 UnitAttr needsBarrier;
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1435struct ReductionPrintArgs {
1440 ReductionModifierAttr modifier;
1442 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1463 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1464 if (argsSubrange.empty())
1467 p << clauseName <<
"(";
1470 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1474 symbols = ArrayAttr::get(ctx, values);
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1491 auto [op, arg, sym, map, isByRef] = t;
1497 p << op <<
" -> " << arg;
1500 p <<
" [map_idx=" << map <<
"]";
1503 llvm::interleaveComma(types, p);
1511 StringRef clauseName,
ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1519 StringRef clauseName,
ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1525 nullptr, privateArgs->needsBarrier);
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms,
nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1550 args.inReductionArgs);
1556 args.reductionArgs);
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1674template <
typename ParsePrefixFn>
1683 if (failed(parsePrefix()))
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1702template <
typename Pr
intPrefixFn>
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1714 p << v <<
" : " << t;
1718 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (
unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1732 <<
"expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->
emitError() <<
"expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1739 return op->
emitOpError() <<
"unexpected reduction symbol references";
1746 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1749 if (!accumulators.insert(accum).second)
1750 return op->
emitOpError() <<
"accumulator variable used more than once";
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1757 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1758 <<
" to point to a reduction declaration";
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1762 <<
"expected accumulator (" << varType
1763 <<
") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() <<
")";
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1792 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](
const auto &args) {
1806 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1807 << std::get<2>(args);
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 <<
") and functions (= " << copyprivateSymsSize
1821 <<
"), both must be equal";
1822 if (!copyprivateSyms.has_value())
1825 for (
auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1831 if (mlir::func::FuncOp mlirFuncOp =
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1838 funcOp = llvmFuncOp;
1840 auto getNumArguments = [&] {
1841 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1844 auto getArgumentType = [&](
unsigned i) {
1845 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1850 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1851 <<
" to point to a copy function";
1853 if (getNumArguments() != 2)
1855 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->
emitOpError() <<
"expected copy function " << symbolRef
1860 <<
" arguments to have the same type";
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1865 <<
"expected copy function arguments' type (" << argTy
1866 <<
") to be the same as copyprivate variable's type (" << varType
1891 OpAsmParser::UnresolvedOperand operand;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1915 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1917 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1924 std::optional<ArrayAttr> dependKinds,
1927 std::optional<ArrayAttr> iteratedKinds) {
1930 std::optional<ArrayAttr> kinds) {
1931 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1937 <<
" -> " << vars[i] <<
" : " << types[i];
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1947 std::optional<ArrayAttr> dependKinds,
1949 std::optional<ArrayAttr> iteratedKinds,
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->
emitOpError() <<
"expected as many depend values"
1954 " as depend variables";
1956 if (dependKinds && !dependKinds->empty())
1957 return op->
emitOpError() <<
"unexpected depend values";
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->
emitOpError() <<
"expected as many depend iterated values"
1963 " as depend iterated variables";
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->
emitOpError() <<
"unexpected depend iterated values";
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1988 auto parseKeyword = [&]() -> ParseResult {
1991 if (hintKeyword ==
"uncontended")
1993 else if (hintKeyword ==
"contended")
1995 else if (hintKeyword ==
"nonspeculative")
1997 else if (hintKeyword ==
"speculative")
2001 << hintKeyword <<
" is not a valid hint";
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2021 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2030 hints.push_back(
"uncontended");
2032 hints.push_back(
"contended");
2034 hints.push_back(
"nonspeculative");
2036 hints.push_back(
"speculative");
2038 llvm::interleaveComma(hints, p);
2045 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2052 if (uncontended && contended)
2053 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2068 return (value & flag) == flag;
2076static ParseResult parseMapClause(
OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2086 if (mapTypeMod ==
"always")
2087 mapTypeBits |= ClauseMapFlags::always;
2089 if (mapTypeMod ==
"implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2092 if (mapTypeMod ==
"ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2095 if (mapTypeMod ==
"close")
2096 mapTypeBits |= ClauseMapFlags::close;
2098 if (mapTypeMod ==
"present")
2099 mapTypeBits |= ClauseMapFlags::present;
2101 if (mapTypeMod ==
"to")
2102 mapTypeBits |= ClauseMapFlags::to;
2104 if (mapTypeMod ==
"from")
2105 mapTypeBits |= ClauseMapFlags::from;
2107 if (mapTypeMod ==
"tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2110 if (mapTypeMod ==
"delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2113 if (mapTypeMod ==
"storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2116 if (mapTypeMod ==
"return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2119 if (mapTypeMod ==
"private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2122 if (mapTypeMod ==
"literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2125 if (mapTypeMod ==
"attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2128 if (mapTypeMod ==
"attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2131 if (mapTypeMod ==
"attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2134 if (mapTypeMod ==
"attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2137 if (mapTypeMod ==
"ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2140 if (mapTypeMod ==
"ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2143 if (mapTypeMod ==
"ref_ptr_ptee")
2144 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
2146 if (mapTypeMod ==
"is_device_ptr")
2147 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2164 ClauseMapFlagsAttr mapType) {
2166 ClauseMapFlags mapFlags = mapType.getValue();
2171 mapTypeStrs.push_back(
"always");
2173 mapTypeStrs.push_back(
"implicit");
2175 mapTypeStrs.push_back(
"ompx_hold");
2177 mapTypeStrs.push_back(
"close");
2179 mapTypeStrs.push_back(
"present");
2188 mapTypeStrs.push_back(
"tofrom");
2190 mapTypeStrs.push_back(
"from");
2192 mapTypeStrs.push_back(
"to");
2195 mapTypeStrs.push_back(
"delete");
2197 mapTypeStrs.push_back(
"return_param");
2199 mapTypeStrs.push_back(
"storage");
2201 mapTypeStrs.push_back(
"private");
2203 mapTypeStrs.push_back(
"literal");
2205 mapTypeStrs.push_back(
"attach");
2207 mapTypeStrs.push_back(
"attach_always");
2209 mapTypeStrs.push_back(
"attach_never");
2211 mapTypeStrs.push_back(
"attach_auto");
2213 mapTypeStrs.push_back(
"ref_ptr");
2215 mapTypeStrs.push_back(
"ref_ptee");
2217 mapTypeStrs.push_back(
"ref_ptr_ptee");
2219 mapTypeStrs.push_back(
"is_device_ptr");
2220 if (mapFlags == ClauseMapFlags::none)
2221 mapTypeStrs.push_back(
"none");
2223 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2224 p << mapTypeStrs[i];
2225 if (i + 1 < mapTypeStrs.size()) {
2231static ParseResult parseMembersIndex(
OpAsmParser &parser,
2235 auto parseIndices = [&]() -> ParseResult {
2240 APInt(64, value,
false)));
2254 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2258 if (!memberIdxs.empty())
2259 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2269 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2271 auto memberIdx = cast<ArrayAttr>(v);
2272 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2273 p << cast<IntegerAttr>(v2).getInt();
2280 VariableCaptureKindAttr mapCaptureType) {
2281 std::string typeCapStr;
2282 llvm::raw_string_ostream typeCap(typeCapStr);
2283 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2285 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2286 typeCap <<
"ByCopy";
2287 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2288 typeCap <<
"VLAType";
2289 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2295 VariableCaptureKindAttr &mapCaptureType) {
2296 StringRef mapCaptureKey;
2300 if (mapCaptureKey ==
"This")
2301 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2302 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2303 if (mapCaptureKey ==
"ByRef")
2304 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2305 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2306 if (mapCaptureKey ==
"ByCopy")
2307 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2308 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2309 if (mapCaptureKey ==
"VLAType")
2310 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2311 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2320 for (
auto mapOp : mapVars) {
2321 if (!mapOp.getDefiningOp())
2324 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2325 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2328 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2331 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2332 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2333 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2335 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2337 "to, from, tofrom and alloc map types are permitted");
2339 if (isa<TargetEnterDataOp>(op) && (from || del))
2340 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2342 if (isa<TargetExitDataOp>(op) && to)
2344 "from, release and delete map types are permitted");
2346 if (isa<TargetUpdateOp>(op)) {
2349 "at least one of to or from map types must be "
2350 "specified, other map types are not permitted");
2355 "at least one of to or from map types must be "
2356 "specified, other map types are not permitted");
2359 auto updateVar = mapInfoOp.getVarPtr();
2361 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2362 (from && updateToVars.contains(updateVar))) {
2365 "either to or from map types can be specified, not both");
2368 if (always || close || implicit) {
2371 "present, mapper and iterator map type modifiers are permitted");
2374 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2376 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2378 "map argument is not a map entry operation");
2385template <
typename OpType>
2389 std::optional<DenseI64ArrayAttr> privateMapIndices =
2390 targetOp.getPrivateMapsAttr();
2393 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2398 if (privateMapIndices.value().size() !=
2399 static_cast<int64_t>(privateVars.size()))
2400 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2401 "`private_maps` attribute mismatch");
2411 StringRef clauseName,
2413 for (
Value var : vars)
2414 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2416 <<
"'" << clauseName
2417 <<
"' arguments must be defined by 'omp.map.info' ops";
2421LogicalResult MapInfoOp::verify() {
2422 if (getMapperId() &&
2424 *
this, getMapperIdAttr())) {
2439 const TargetDataOperands &clauses) {
2440 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2441 clauses.mapVars, clauses.useDeviceAddrVars,
2442 clauses.useDevicePtrVars);
2445LogicalResult TargetDataOp::verify() {
2446 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2447 getUseDeviceAddrVars().empty()) {
2448 return ::emitError(this->getLoc(),
2449 "At least one of map, use_device_ptr_vars, or "
2450 "use_device_addr_vars operand must be present");
2454 getUseDevicePtrVars())))
2458 getUseDeviceAddrVars())))
2468void TargetEnterDataOp::build(
2472 TargetEnterDataOp::build(
2474 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2475 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2479LogicalResult TargetEnterDataOp::verify() {
2480 LogicalResult verifyDependVars =
2482 getDependIteratedKinds(), getDependIterated());
2483 return failed(verifyDependVars) ? verifyDependVars
2494 TargetExitDataOp::build(
2496 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2497 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2501LogicalResult TargetExitDataOp::verify() {
2502 LogicalResult verifyDependVars =
2504 getDependIteratedKinds(), getDependIterated());
2505 return failed(verifyDependVars) ? verifyDependVars
2516 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2519 clauses.dependIterated, clauses.device, clauses.ifExpr,
2520 clauses.mapVars, clauses.nowait);
2523LogicalResult TargetUpdateOp::verify() {
2524 LogicalResult verifyDependVars =
2526 getDependIteratedKinds(), getDependIterated());
2527 return failed(verifyDependVars) ? verifyDependVars
2536 const TargetOperands &clauses) {
2541 builder, state, {}, {}, clauses.bare,
2542 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2543 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2544 clauses.device, clauses.dynGroupprivateAccessGroup,
2545 clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
2546 clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
2548 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2549 clauses.nowait, clauses.privateVars,
2550 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2551 clauses.threadLimitVars,
2555LogicalResult TargetOp::verify() {
2557 getDependIteratedKinds(),
2558 getDependIterated())))
2562 getHasDeviceAddrVars())))
2569 *
this, getDynGroupprivateAccessGroupAttr(),
2570 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2579LogicalResult TargetOp::verifyRegions() {
2580 auto teamsOps = getOps<TeamsOp>();
2581 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2582 return emitError(
"target containing multiple 'omp.teams' nested ops");
2585 bool hostEvalTripCount;
2586 Operation *capturedOp = getInnermostCapturedOmpOp();
2587 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2588 for (
Value hostEvalArg :
2589 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2591 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2593 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2594 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2595 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2598 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2599 "and 'thread_limit' in 'omp.teams'";
2601 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2602 if (execMode == TargetExecMode::spmd &&
2603 parallelOp->isAncestor(capturedOp) &&
2604 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2608 <<
"host_eval argument only legal as 'num_threads' in "
2609 "'omp.parallel' when representing target SPMD";
2611 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2612 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2613 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2614 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2615 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2618 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2619 "and steps in 'omp.loop_nest' when trip count "
2620 "must be evaluated in the host";
2623 return emitOpError() <<
"host_eval argument illegal use in '"
2624 << user->getName() <<
"' operation";
2633 assert(rootOp &&
"expected valid operation");
2650 bool isOmpDialect = op->
getDialect() == ompDialect;
2652 if (!isOmpDialect || !hasRegions)
2659 if (checkSingleMandatoryExec) {
2664 if (successor->isReachable(parentBlock))
2667 for (
Block &block : *parentRegion)
2669 !domInfo.
dominates(parentBlock, &block))
2676 if (&sibling != op && !siblingAllowedFn(&sibling))
2689Operation *TargetOp::getInnermostCapturedOmpOp() {
2690 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2702 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2705 memOp.getEffects(effects);
2706 return !llvm::any_of(
2708 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2709 isa<SideEffects::AutomaticAllocationScopeResource>(
2719 WsloopOp *wsLoopOp) {
2721 if (!teamsOp.getNumTeamsUpperVars().empty())
2725 if (teamsOp.getNumReductionVars())
2727 if (wsLoopOp->getNumReductionVars())
2731 OffloadModuleInterface offloadMod =
2735 auto ompFlags = offloadMod.getFlags();
2738 return ompFlags.getAssumeTeamsOversubscription() &&
2739 ompFlags.getAssumeThreadsOversubscription();
2742TargetExecMode TargetOp::getKernelExecFlags(
Operation *capturedOp,
2743 bool *hostEvalTripCount) {
2749 assert((!capturedOp ||
2750 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2751 "unexpected captured op");
2753 if (hostEvalTripCount)
2754 *hostEvalTripCount =
false;
2757 if (!isa_and_present<LoopNestOp>(capturedOp))
2758 return TargetExecMode::generic;
2762 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2763 assert(!loopWrappers.empty());
2765 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2766 if (isa<SimdOp>(innermostWrapper))
2767 innermostWrapper = std::next(innermostWrapper);
2769 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2770 if (numWrappers != 1 && numWrappers != 2)
2771 return TargetExecMode::generic;
2774 if (numWrappers == 2) {
2775 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2777 return TargetExecMode::generic;
2779 innermostWrapper = std::next(innermostWrapper);
2780 if (!isa<DistributeOp>(innermostWrapper))
2781 return TargetExecMode::generic;
2784 if (!isa_and_present<ParallelOp>(parallelOp))
2785 return TargetExecMode::generic;
2787 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2789 return TargetExecMode::generic;
2791 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2792 TargetExecMode
result = TargetExecMode::spmd;
2794 result = TargetExecMode::no_loop;
2795 if (hostEvalTripCount)
2796 *hostEvalTripCount =
true;
2801 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2803 if (!isa_and_present<TeamsOp>(teamsOp))
2804 return TargetExecMode::generic;
2806 if (teamsOp->
getParentOp() != targetOp.getOperation())
2807 return TargetExecMode::generic;
2809 if (hostEvalTripCount)
2810 *hostEvalTripCount =
true;
2812 if (isa<LoopOp>(innermostWrapper))
2813 return TargetExecMode::spmd;
2815 return TargetExecMode::generic;
2818 else if (isa<WsloopOp>(innermostWrapper)) {
2820 if (!isa_and_present<ParallelOp>(parallelOp))
2821 return TargetExecMode::generic;
2823 if (parallelOp->
getParentOp() == targetOp.getOperation())
2824 return TargetExecMode::spmd;
2827 return TargetExecMode::generic;
2836 ParallelOp::build(builder, state,
ValueRange(),
2848 const ParallelOperands &clauses) {
2850 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2851 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2853 clauses.privateNeedsBarrier, clauses.procBindKind,
2854 clauses.reductionMod, clauses.reductionVars,
2859template <
typename OpType>
2861 auto privateVars = op.getPrivateVars();
2862 auto privateSyms = op.getPrivateSymsAttr();
2864 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2867 auto numPrivateVars = privateVars.size();
2868 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2870 if (numPrivateVars != numPrivateSyms)
2871 return op.emitError() <<
"inconsistent number of private variables and "
2872 "privatizer op symbols, private vars: "
2874 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2876 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2877 Type varType = std::get<0>(privateVarInfo).getType();
2878 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2879 PrivateClauseOp privatizerOp =
2882 if (privatizerOp ==
nullptr)
2883 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2884 << privateSym <<
"'";
2886 Type privatizerType = privatizerOp.getArgType();
2888 if (privatizerType && (varType != privatizerType))
2889 return op.emitError()
2890 <<
"type mismatch between a "
2891 << (privatizerOp.getDataSharingType() ==
2892 DataSharingClauseType::Private
2895 <<
" variable and its privatizer op, var type: " << varType
2896 <<
" vs. privatizer op type: " << privatizerType;
2902LogicalResult ParallelOp::verify() {
2903 if (getAllocateVars().size() != getAllocatorVars().size())
2905 "expected equal sizes for allocate and allocator variables");
2911 getReductionByref());
2914LogicalResult ParallelOp::verifyRegions() {
2915 auto distChildOps = getOps<DistributeOp>();
2916 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2917 if (numDistChildOps > 1)
2919 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2921 if (numDistChildOps == 1) {
2924 <<
"'omp.composite' attribute missing from composite operation";
2926 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2927 Operation &distributeOp = **distChildOps.begin();
2929 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2933 return emitError() <<
"unexpected OpenMP operation inside of composite "
2935 << childOp.getName();
2937 }
else if (isComposite()) {
2939 <<
"'omp.composite' attribute present in non-composite operation";
2956 const TeamsOperands &clauses) {
2960 builder, state, clauses.allocateVars, clauses.allocatorVars,
2961 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
2962 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
2963 clauses.numTeamsUpperVars, {},
nullptr,
2964 nullptr, clauses.reductionMod,
2965 clauses.reductionVars,
2967 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2974 if (numTeamsLower) {
2975 if (numTeamsUpperVars.size() != 1)
2977 "expected exactly one num_teams upper bound when lower bound is "
2981 "expected num_teams upper bound and lower bound to be "
2988LogicalResult TeamsOp::verify() {
2997 return emitError(
"expected to be nested inside of omp.target or not nested "
2998 "in any OpenMP dialect operations");
3002 this->getNumTeamsUpperVars())))
3006 if (getAllocateVars().size() != getAllocatorVars().size())
3008 "expected equal sizes for allocate and allocator variables");
3011 op, getDynGroupprivateAccessGroupAttr(),
3012 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
3019 getReductionByref());
3027 return getParentOp().getPrivateVars();
3031 return getParentOp().getReductionVars();
3039 const SectionsOperands &clauses) {
3042 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3045 clauses.reductionMod, clauses.reductionVars,
3050LogicalResult SectionsOp::verify() {
3051 if (getAllocateVars().size() != getAllocatorVars().size())
3053 "expected equal sizes for allocate and allocator variables");
3056 getReductionByref());
3059LogicalResult SectionsOp::verifyRegions() {
3060 for (
auto &inst : *getRegion().begin()) {
3061 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3063 <<
"expected omp.section op or terminator op inside region";
3075 const ScopeOperands &clauses) {
3077 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3078 clauses.nowait, clauses.privateVars,
3080 clauses.privateNeedsBarrier, clauses.reductionMod,
3081 clauses.reductionVars,
3086LogicalResult ScopeOp::verify() {
3087 if (getAllocateVars().size() != getAllocatorVars().size())
3089 "expected equal sizes for allocate and allocator variables");
3095 getReductionByref());
3103 const SingleOperands &clauses) {
3106 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3107 clauses.copyprivateVars,
3108 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3113LogicalResult SingleOp::verify() {
3115 if (getAllocateVars().size() != getAllocatorVars().size())
3117 "expected equal sizes for allocate and allocator variables");
3120 getCopyprivateSyms());
3128 const WorkshareOperands &clauses) {
3129 WorkshareOp::build(builder, state, clauses.nowait);
3136LogicalResult WorkshareLoopWrapperOp::verify() {
3137 if (!(*this)->getParentOfType<WorkshareOp>())
3138 return emitOpError() <<
"must be nested in an omp.workshare";
3142LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3143 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3145 return emitOpError() <<
"expected to be a standalone loop wrapper";
3154LogicalResult LoopWrapperInterface::verifyImpl() {
3158 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
3159 "and `SingleBlock` traits";
3162 return emitOpError() <<
"loop wrapper does not contain exactly one region";
3165 if (range_size(region.
getOps()) != 1)
3167 <<
"loop wrapper does not contain exactly one nested op";
3170 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
3171 return emitOpError() <<
"nested in loop wrapper is not another loop "
3172 "wrapper or `omp.loop_nest`";
3182 const LoopOperands &clauses) {
3185 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3187 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3188 clauses.reductionMod, clauses.reductionVars,
3193LogicalResult LoopOp::verify() {
3198 getReductionByref());
3201LogicalResult LoopOp::verifyRegions() {
3202 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3204 return emitOpError() <<
"expected to be a standalone loop wrapper";
3215 build(builder, state, {}, {},
3218 false,
nullptr,
nullptr,
3219 nullptr, {},
nullptr,
3230 const WsloopOperands &clauses) {
3235 {}, {}, clauses.linearVars,
3236 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3237 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3238 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3239 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3241 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3242 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3245LogicalResult WsloopOp::verify() {
3249 if (getLinearVars().size() &&
3250 getLinearVarTypes().value().size() != getLinearVars().size())
3251 return emitError() <<
"Ill-formed type attributes for linear variables";
3257 getReductionByref());
3260LogicalResult WsloopOp::verifyRegions() {
3261 bool isCompositeChildLeaf =
3262 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3264 if (LoopWrapperInterface nested = getNestedWrapper()) {
3267 <<
"'omp.composite' attribute missing from composite wrapper";
3271 if (!isa<SimdOp>(nested))
3272 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3274 }
else if (isComposite() && !isCompositeChildLeaf) {
3276 <<
"'omp.composite' attribute present in non-composite wrapper";
3277 }
else if (!isComposite() && isCompositeChildLeaf) {
3279 <<
"'omp.composite' attribute missing from composite wrapper";
3290 const SimdOperands &clauses) {
3292 SimdOp::build(builder, state, clauses.alignedVars,
3294 clauses.linearVars, clauses.linearStepVars,
3295 clauses.linearVarTypes, clauses.linearModifiers,
3296 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3297 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3298 clauses.privateNeedsBarrier, clauses.reductionMod,
3299 clauses.reductionVars,
3305LogicalResult SimdOp::verify() {
3306 if (getSimdlen().has_value() && getSafelen().has_value() &&
3307 getSimdlen().value() > getSafelen().value())
3309 <<
"simdlen clause and safelen clause are both present, but the "
3310 "simdlen value is not less than or equal to safelen value";
3322 bool isCompositeChildLeaf =
3323 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3325 if (!isComposite() && isCompositeChildLeaf)
3327 <<
"'omp.composite' attribute missing from composite wrapper";
3329 if (isComposite() && !isCompositeChildLeaf)
3331 <<
"'omp.composite' attribute present in non-composite wrapper";
3335 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3337 for (
const Attribute &sym : *privateSyms) {
3338 auto symRef = cast<SymbolRefAttr>(sym);
3339 omp::PrivateClauseOp privatizer =
3341 getOperation(), symRef);
3343 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3344 if (privatizer.getDataSharingType() ==
3345 DataSharingClauseType::FirstPrivate)
3346 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3353 if (getLinearVars().size() &&
3354 getLinearVarTypes().value().size() != getLinearVars().size())
3355 return emitError() <<
"Ill-formed type attributes for linear variables";
3359LogicalResult SimdOp::verifyRegions() {
3360 if (getNestedWrapper())
3361 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3371 const DistributeOperands &clauses) {
3372 DistributeOp::build(builder, state, clauses.allocateVars,
3373 clauses.allocatorVars, clauses.distScheduleStatic,
3374 clauses.distScheduleChunkSize, clauses.order,
3375 clauses.orderMod, clauses.privateVars,
3377 clauses.privateNeedsBarrier);
3380LogicalResult DistributeOp::verify() {
3381 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3383 "dist_schedule_static being present";
3385 if (getAllocateVars().size() != getAllocatorVars().size())
3387 "expected equal sizes for allocate and allocator variables");
3395LogicalResult DistributeOp::verifyRegions() {
3396 if (LoopWrapperInterface nested = getNestedWrapper()) {
3399 <<
"'omp.composite' attribute missing from composite wrapper";
3402 if (isa<WsloopOp>(nested)) {
3404 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3405 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3406 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3407 "when a composite 'omp.parallel' is the direct "
3410 }
else if (!isa<SimdOp>(nested))
3411 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3413 }
else if (isComposite()) {
3415 <<
"'omp.composite' attribute present in non-composite wrapper";
3425LogicalResult DeclareMapperInfoOp::verify() {
3429LogicalResult DeclareMapperOp::verifyRegions() {
3430 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3431 getRegion().getBlocks().front().getTerminator()))
3432 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3441LogicalResult DeclareReductionOp::verifyRegions() {
3442 if (!getAllocRegion().empty()) {
3443 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3444 if (yieldOp.getResults().size() != 1 ||
3445 yieldOp.getResults().getTypes()[0] !=
getType())
3446 return emitOpError() <<
"expects alloc region to yield a value "
3447 "of the reduction type";
3451 if (getInitializerRegion().empty())
3452 return emitOpError() <<
"expects non-empty initializer region";
3453 Block &initializerEntryBlock = getInitializerRegion().
front();
3456 if (!getAllocRegion().empty())
3457 return emitOpError() <<
"expects two arguments to the initializer region "
3458 "when an allocation region is used";
3460 if (getAllocRegion().empty())
3461 return emitOpError() <<
"expects one argument to the initializer region "
3462 "when no allocation region is used";
3465 <<
"expects one or two arguments to the initializer region";
3469 if (arg.getType() !=
getType())
3470 return emitOpError() <<
"expects initializer region argument to match "
3471 "the reduction type";
3473 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3474 if (yieldOp.getResults().size() != 1 ||
3475 yieldOp.getResults().getTypes()[0] !=
getType())
3476 return emitOpError() <<
"expects initializer region to yield a value "
3477 "of the reduction type";
3480 if (getReductionRegion().empty())
3481 return emitOpError() <<
"expects non-empty reduction region";
3482 Block &reductionEntryBlock = getReductionRegion().
front();
3487 return emitOpError() <<
"expects reduction region with two arguments of "
3488 "the reduction type";
3489 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3490 if (yieldOp.getResults().size() != 1 ||
3491 yieldOp.getResults().getTypes()[0] !=
getType())
3492 return emitOpError() <<
"expects reduction region to yield a value "
3493 "of the reduction type";
3496 if (!getAtomicReductionRegion().empty()) {
3497 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3501 return emitOpError() <<
"expects atomic reduction region with two "
3502 "arguments of the same type";
3503 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3506 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3507 return emitOpError() <<
"expects atomic reduction region arguments to "
3508 "be accumulators containing the reduction type";
3511 if (getCleanupRegion().empty())
3513 Block &cleanupEntryBlock = getCleanupRegion().
front();
3516 return emitOpError() <<
"expects cleanup region with one argument "
3517 "of the reduction type";
3527 const TaskOperands &clauses) {
3530 builder, state, clauses.iterated, clauses.affinityVars,
3531 clauses.allocateVars, clauses.allocatorVars,
3532 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3533 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3534 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3536 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3537 clauses.priority, clauses.privateVars,
3539 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3542LogicalResult TaskOp::verify() {
3543 LogicalResult verifyDependVars =
3545 getDependIteratedKinds(), getDependIterated());
3546 if (
failed(verifyDependVars))
3547 return verifyDependVars;
3553 getInReductionVars(), getInReductionByref());
3561 const TaskgroupOperands &clauses) {
3563 TaskgroupOp::build(builder, state, clauses.allocateVars,
3564 clauses.allocatorVars, clauses.taskReductionVars,
3569LogicalResult TaskgroupOp::verify() {
3571 getTaskReductionVars(),
3572 getTaskReductionByref());
3580 const TaskloopContextOperands &clauses) {
3582 TaskloopContextOp::build(
3583 builder, state, clauses.allocateVars, clauses.allocatorVars,
3584 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3585 clauses.inReductionVars,
3587 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3588 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3589 clauses.privateVars,
3591 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3596TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3597 return cast<TaskloopWrapperOp>(
3599 return isa<TaskloopWrapperOp>(op);
3603LogicalResult TaskloopContextOp::verify() {
3604 if (getAllocateVars().size() != getAllocatorVars().size())
3606 "expected equal sizes for allocate and allocator variables");
3612 getReductionVars(), getReductionByref())) ||
3614 getInReductionVars(),
3615 getInReductionByref())))
3618 if (!getReductionVars().empty() && getNogroup())
3619 return emitError(
"if a reduction clause is present on the taskloop "
3620 "directive, the nogroup clause must not be specified");
3621 for (
auto var : getReductionVars()) {
3622 if (llvm::is_contained(getInReductionVars(), var))
3623 return emitError(
"the same list item cannot appear in both a reduction "
3624 "and an in_reduction clause");
3627 if (getGrainsize() && getNumTasks()) {
3629 "the grainsize clause and num_tasks clause are mutually exclusive and "
3630 "may not appear on the same taskloop directive");
3636LogicalResult TaskloopContextOp::verifyRegions() {
3637 Region ®ion = getRegion();
3639 return emitOpError() <<
"expected non-empty region";
3642 return isa<TaskloopWrapperOp>(op);
3646 <<
"expected exactly 1 TaskloopWrapperOp directly nested in "
3648 << count <<
" were found";
3649 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3651 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3657 std::function<
bool(
Value)> isValidBoundValue = [&](
Value value) ->
bool {
3658 Region *valueRegion = value.getParentRegion();
3664 Operation *defOp = value.getDefiningOp();
3668 return llvm::all_of(defOp->
getOperands(), isValidBoundValue);
3670 auto hasUnsupportedTaskloopLocalBound = [&](
OperandRange range) ->
bool {
3671 return llvm::any_of(range,
3672 [&](
Value value) {
return !isValidBoundValue(value); });
3675 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3676 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3677 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3679 <<
"expects loop bounds and steps to be defined outside of the "
3680 "taskloop.context region or by pure, regionless operations "
3681 "that do not depend on block arguments";
3692 const TaskloopWrapperOperands &clauses) {
3693 TaskloopWrapperOp::build(builder, state);
3696TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3697 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3700LogicalResult TaskloopWrapperOp::verify() {
3701 TaskloopContextOp context = getTaskloopContext();
3703 return emitOpError() <<
"expected to be nested in a taskloop context op";
3707LogicalResult TaskloopWrapperOp::verifyRegions() {
3708 if (LoopWrapperInterface nested = getNestedWrapper()) {
3711 <<
"'omp.composite' attribute missing from composite wrapper";
3715 if (!isa<SimdOp>(nested))
3716 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3717 }
else if (isComposite()) {
3719 <<
"'omp.composite' attribute present in non-composite wrapper";
3743 for (
auto &iv : ivs)
3744 iv.type = loopVarType;
3749 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3765 "collapse_num_loops",
3770 auto parseTiles = [&]() -> ParseResult {
3774 tiles.push_back(
tile);
3783 if (tiles.size() > 0)
3802 Region ®ion = getRegion();
3804 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3805 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3806 if (getLoopInclusive())
3808 p <<
"step (" << getLoopSteps() <<
") ";
3809 if (
int64_t numCollapse = getCollapseNumLoops())
3810 if (numCollapse > 1)
3811 p <<
"collapse(" << numCollapse <<
") ";
3814 p <<
"tiles(" << tiles.value() <<
") ";
3820 const LoopNestOperands &clauses) {
3822 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3823 clauses.loopLowerBounds, clauses.loopUpperBounds,
3824 clauses.loopSteps, clauses.loopInclusive,
3828LogicalResult LoopNestOp::verify() {
3829 if (getLoopLowerBounds().empty())
3830 return emitOpError() <<
"must represent at least one loop";
3832 if (getLoopLowerBounds().size() != getIVs().size())
3833 return emitOpError() <<
"number of range arguments and IVs do not match";
3835 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3836 if (lb.getType() != iv.getType())
3838 <<
"range argument type does not match corresponding IV type";
3841 uint64_t numIVs = getIVs().size();
3843 if (
const auto &numCollapse = getCollapseNumLoops())
3844 if (numCollapse > numIVs)
3846 <<
"collapse value is larger than the number of loops";
3849 if (tiles.value().size() > numIVs)
3850 return emitOpError() <<
"too few canonical loops for tile dimensions";
3852 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3853 return emitOpError() <<
"expects parent op to be a loop wrapper";
3858void LoopNestOp::gatherWrappers(
3861 while (
auto wrapper =
3862 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3863 wrappers.push_back(wrapper);
3872std::tuple<NewCliOp, OpOperand *, OpOperand *>
3878 return {{},
nullptr,
nullptr};
3881 "Unexpected type of cli");
3887 auto op = cast<LoopTransformationInterface>(use.getOwner());
3889 unsigned opnum = use.getOperandNumber();
3890 if (op.isGeneratee(opnum)) {
3891 assert(!gen &&
"Each CLI may have at most one def");
3893 }
else if (op.isApplyee(opnum)) {
3894 assert(!cons &&
"Each CLI may have at most one consumer");
3897 llvm_unreachable(
"Unexpected operand for a CLI");
3901 return {create, gen, cons};
3924 std::string cliName{
"cli"};
3928 .Case([&](CanonicalLoopOp op) {
3931 .Case([&](UnrollHeuristicOp op) -> std::string {
3932 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3934 .Case([&](FuseOp op) -> std::string {
3935 unsigned opnum =
generator->getOperandNumber();
3938 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3939 return "canonloop_fuse";
3943 .Case([&](TileOp op) -> std::string {
3944 auto [generateesFirst, generateesCount] =
3945 op.getGenerateesODSOperandIndexAndLength();
3946 unsigned firstGrid = generateesFirst;
3947 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3948 unsigned end = generateesFirst + generateesCount;
3949 unsigned opnum =
generator->getOperandNumber();
3951 if (firstGrid <= opnum && opnum < firstIntratile) {
3952 unsigned gridnum = opnum - firstGrid + 1;
3953 return (
"grid" + Twine(gridnum)).str();
3955 if (firstIntratile <= opnum && opnum < end) {
3956 unsigned intratilenum = opnum - firstIntratile + 1;
3957 return (
"intratile" + Twine(intratilenum)).str();
3959 llvm_unreachable(
"Unexpected generatee argument");
3961 .DefaultUnreachable(
"TODO: Custom name for this operation");
3964 setNameFn(
result, cliName);
3967LogicalResult NewCliOp::verify() {
3968 Value cli = getResult();
3971 "Unexpected type of cli");
3977 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3979 unsigned opnum = use.getOperandNumber();
3980 if (op.isGeneratee(opnum)) {
3983 emitOpError(
"CLI must have at most one generator");
3985 .
append(
"first generator here:");
3987 .
append(
"second generator here:");
3992 }
else if (op.isApplyee(opnum)) {
3995 emitOpError(
"CLI must have at most one consumer");
3997 .
append(
"first consumer here:")
4001 .
append(
"second consumer here:")
4008 llvm_unreachable(
"Unexpected operand for a CLI");
4016 .
append(
"see consumer here: ")
4039 setNameFn(&getRegion().front(),
"body_entry");
4042void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
4050 p <<
'(' << getCli() <<
')';
4051 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
4052 <<
" in range(" << getTripCount() <<
") ";
4062 CanonicalLoopInfoType cliType =
4063 CanonicalLoopInfoType::get(parser.
getContext());
4088 if (parser.
parseRegion(*region, {inductionVariable}))
4093 result.operands.append(cliOperand);
4099 return mlir::success();
4102LogicalResult CanonicalLoopOp::verify() {
4105 if (!getRegion().empty()) {
4106 Region ®ion = getRegion();
4109 "Canonical loop region must have exactly one argument");
4113 "Region argument must be the same type as the trip count");
4119Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
4121std::pair<unsigned, unsigned>
4122CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4127std::pair<unsigned, unsigned>
4128CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4129 return getODSOperandIndexAndLength(odsIndex_cli);
4143 p <<
'(' << getApplyee() <<
')';
4150 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
4173 return mlir::success();
4176std::pair<unsigned, unsigned>
4177UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4178 return getODSOperandIndexAndLength(odsIndex_applyee);
4181std::pair<unsigned, unsigned>
4182UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4193 if (!generatees.empty())
4194 p <<
'(' << llvm::interleaved(generatees) <<
')';
4196 if (!applyees.empty())
4197 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4239 bool isOnlyCanonLoops =
true;
4241 for (
Value applyee : op.getApplyees()) {
4242 auto [create, gen, cons] =
decodeCli(applyee);
4245 return op.emitOpError() <<
"applyee CLI has no generator";
4247 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4248 canonLoops.push_back(loop);
4250 isOnlyCanonLoops =
false;
4255 if (!isOnlyCanonLoops)
4259 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
4260 auto parentLoop = canonLoops[i - 1];
4261 auto loop = canonLoops[i];
4263 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4264 return op.emitOpError()
4265 <<
"tiled loop nest must be nested within each other";
4267 parentIVs.insert(parentLoop.getInductionVar());
4272 bool isPerfectlyNested = [&]() {
4273 auto &parentBody = parentLoop.getRegion();
4274 if (!parentBody.hasOneBlock())
4276 auto &parentBlock = parentBody.getBlocks().front();
4278 auto nestedLoopIt = parentBlock.begin();
4279 if (nestedLoopIt == parentBlock.end() ||
4280 (&*nestedLoopIt != loop.getOperation()))
4283 auto termIt = std::next(nestedLoopIt);
4284 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4287 if (std::next(termIt) != parentBlock.end())
4292 if (!isPerfectlyNested)
4293 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4295 if (parentIVs.contains(loop.getTripCount()))
4296 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4313LogicalResult TileOp::verify() {
4314 if (getApplyees().empty())
4315 return emitOpError() <<
"must apply to at least one loop";
4317 if (getSizes().size() != getApplyees().size())
4318 return emitOpError() <<
"there must be one tile size for each applyee";
4320 if (!getGeneratees().empty() &&
4321 2 * getSizes().size() != getGeneratees().size())
4323 <<
"expecting two times the number of generatees than applyees";
4328std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4329 return getODSOperandIndexAndLength(odsIndex_applyees);
4332std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4333 return getODSOperandIndexAndLength(odsIndex_generatees);
4343 if (!generatees.empty())
4344 p <<
'(' << llvm::interleaved(generatees) <<
')';
4346 if (!applyees.empty())
4347 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4350LogicalResult FuseOp::verify() {
4351 if (getApplyees().size() < 2)
4352 return emitOpError() <<
"must apply to at least two loops";
4354 if (getFirst().has_value() && getCount().has_value()) {
4355 int64_t first = getFirst().value();
4356 int64_t count = getCount().value();
4357 if ((
unsigned)(first + count - 1) > getApplyees().size())
4358 return emitOpError() <<
"the numbers of applyees must be at least first "
4359 "minus one plus count attributes";
4360 if (!getGeneratees().empty() &&
4361 getGeneratees().size() != getApplyees().size() + 1 - count)
4362 return emitOpError() <<
"the number of generatees must be the number of "
4363 "aplyees plus one minus count";
4366 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4368 <<
"in a complete fuse the number of generatees must be exactly 1";
4370 for (
auto &&applyee : getApplyees()) {
4371 auto [create, gen, cons] =
decodeCli(applyee);
4374 return emitOpError() <<
"applyee CLI has no generator";
4375 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4378 <<
"currently only supports omp.canonical_loop as applyee";
4382std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4383 return getODSOperandIndexAndLength(odsIndex_applyees);
4386std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4387 return getODSOperandIndexAndLength(odsIndex_generatees);
4395 const CriticalDeclareOperands &clauses) {
4396 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4399LogicalResult CriticalDeclareOp::verify() {
4404 if (getNameAttr()) {
4405 SymbolRefAttr symbolRef = getNameAttr();
4409 return emitOpError() <<
"expected symbol reference " << symbolRef
4410 <<
" to point to a critical declaration";
4430 return op.
emitOpError() <<
"must be nested inside of a loop";
4434 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4435 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4437 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4438 "have an ordered clause";
4440 if (hasRegion && orderedAttr.getInt() != 0)
4441 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4442 "have a parameter present";
4444 if (!hasRegion && orderedAttr.getInt() == 0)
4445 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4446 "have a parameter present";
4447 }
else if (!isa<SimdOp>(wrapper)) {
4448 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4449 "or worksharing simd loop";
4455 const OrderedOperands &clauses) {
4456 OrderedOp::build(builder, state, clauses.doacrossDependType,
4457 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4460LogicalResult OrderedOp::verify() {
4464 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4465 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4466 return emitOpError() <<
"number of variables in depend clause does not "
4467 <<
"match number of iteration variables in the "
4474 const OrderedRegionOperands &clauses) {
4475 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4485 const TaskwaitOperands &clauses) {
4487 TaskwaitOp::build(builder, state,
nullptr,
4496LogicalResult AtomicReadOp::verify() {
4497 if (verifyCommon().
failed())
4498 return mlir::failure();
4500 if (
auto mo = getMemoryOrder()) {
4501 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4502 *mo == ClauseMemoryOrderKind::Release) {
4504 "memory-order must not be acq_rel or release for atomic reads");
4514LogicalResult AtomicWriteOp::verify() {
4515 if (verifyCommon().
failed())
4516 return mlir::failure();
4518 if (
auto mo = getMemoryOrder()) {
4519 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4520 *mo == ClauseMemoryOrderKind::Acquire) {
4522 "memory-order must not be acq_rel or acquire for atomic writes");
4532LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4538 if (
Value writeVal = op.getWriteOpVal()) {
4540 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4546LogicalResult AtomicUpdateOp::verify() {
4547 if (verifyCommon().
failed())
4548 return mlir::failure();
4550 if (
auto mo = getMemoryOrder()) {
4551 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4552 *mo == ClauseMemoryOrderKind::Acquire) {
4554 "memory-order must not be acq_rel or acquire for atomic updates");
4561LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4567AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4568 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4570 return dyn_cast<AtomicReadOp>(getSecondOp());
4573AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4574 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4576 return dyn_cast<AtomicWriteOp>(getSecondOp());
4579AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4580 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4582 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4585LogicalResult AtomicCaptureOp::verify() {
4589LogicalResult AtomicCaptureOp::verifyRegions() {
4590 if (verifyRegionsCommon().
failed())
4591 return mlir::failure();
4593 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4595 "operations inside capture region must not have hint clause");
4597 if (getFirstOp()->getAttr(
"memory_order") ||
4598 getSecondOp()->getAttr(
"memory_order"))
4600 "operations inside capture region must not have memory_order clause");
4609 const CancelOperands &clauses) {
4610 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4623LogicalResult CancelOp::verify() {
4624 ClauseCancellationConstructType cct = getCancelDirective();
4627 if (!structuralParent)
4628 return emitOpError() <<
"Orphaned cancel construct";
4630 if ((cct == ClauseCancellationConstructType::Parallel) &&
4631 !mlir::isa<ParallelOp>(structuralParent)) {
4632 return emitOpError() <<
"cancel parallel must appear "
4633 <<
"inside a parallel region";
4635 if (cct == ClauseCancellationConstructType::Loop) {
4638 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4642 <<
"cancel loop must appear inside a worksharing-loop region";
4644 if (wsloopOp.getNowaitAttr()) {
4645 return emitError() <<
"A worksharing construct that is canceled "
4646 <<
"must not have a nowait clause";
4648 if (wsloopOp.getOrderedAttr()) {
4649 return emitError() <<
"A worksharing construct that is canceled "
4650 <<
"must not have an ordered clause";
4653 }
else if (cct == ClauseCancellationConstructType::Sections) {
4657 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4659 return emitOpError() <<
"cancel sections must appear "
4660 <<
"inside a sections region";
4662 if (sectionsOp.getNowait()) {
4663 return emitError() <<
"A sections construct that is canceled "
4664 <<
"must not have a nowait clause";
4667 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4668 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4669 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4670 return emitOpError() <<
"cancel taskgroup must appear "
4671 <<
"inside a task region";
4681 const CancellationPointOperands &clauses) {
4682 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4685LogicalResult CancellationPointOp::verify() {
4686 ClauseCancellationConstructType cct = getCancelDirective();
4689 if (!structuralParent)
4690 return emitOpError() <<
"Orphaned cancellation point";
4692 if ((cct == ClauseCancellationConstructType::Parallel) &&
4693 !mlir::isa<ParallelOp>(structuralParent)) {
4694 return emitOpError() <<
"cancellation point parallel must appear "
4695 <<
"inside a parallel region";
4699 if ((cct == ClauseCancellationConstructType::Loop) &&
4700 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4701 return emitOpError() <<
"cancellation point loop must appear "
4702 <<
"inside a worksharing-loop region";
4704 if ((cct == ClauseCancellationConstructType::Sections) &&
4705 !mlir::isa<omp::SectionOp>(structuralParent)) {
4706 return emitOpError() <<
"cancellation point sections must appear "
4707 <<
"inside a sections region";
4709 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4710 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4711 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4712 return emitOpError() <<
"cancellation point taskgroup must appear "
4713 <<
"inside a task region";
4722LogicalResult MapBoundsOp::verify() {
4723 auto extent = getExtent();
4725 if (!extent && !upperbound)
4726 return emitError(
"expected extent or upperbound.");
4733 PrivateClauseOp::build(
4734 odsBuilder, odsState, symName, type,
4735 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4736 DataSharingClauseType::Private));
4739LogicalResult PrivateClauseOp::verifyRegions() {
4740 Type argType = getArgType();
4741 auto verifyTerminator = [&](
Operation *terminator,
4742 bool yieldsValue) -> LogicalResult {
4746 if (!llvm::isa<YieldOp>(terminator))
4748 <<
"expected exit block terminator to be an `omp.yield` op.";
4750 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4751 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4754 if (yieldedTypes.empty())
4758 <<
"Did not expect any values to be yielded.";
4761 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4765 <<
"Invalid yielded value. Expected type: " << argType
4768 if (yieldedTypes.empty())
4771 error << yieldedTypes;
4777 StringRef regionName,
4778 bool yieldsValue) -> LogicalResult {
4779 assert(!region.
empty());
4783 <<
"`" << regionName <<
"`: "
4784 <<
"expected " << expectedNumArgs
4787 for (
Block &block : region) {
4789 if (!block.mightHaveTerminator())
4792 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4800 for (
Region *region : getRegions())
4801 for (
Type ty : region->getArgumentTypes())
4803 return emitError() <<
"Region argument type mismatch: got " << ty
4804 <<
" expected " << argType <<
".";
4807 if (!initRegion.
empty() &&
4812 DataSharingClauseType dsType = getDataSharingType();
4814 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4815 return emitError(
"`private` clauses do not require a `copy` region.");
4817 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4819 "`firstprivate` clauses require at least a `copy` region.");
4821 if (dsType == DataSharingClauseType::FirstPrivate &&
4826 if (!getDeallocRegion().empty() &&
4839 const MaskedOperands &clauses) {
4840 MaskedOp::build(builder, state, clauses.filteredThreadId);
4848 const ScanOperands &clauses) {
4849 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4852LogicalResult ScanOp::verify() {
4853 if (hasExclusiveVars() == hasInclusiveVars())
4855 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4856 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4857 if (parentWsLoopOp.getReductionModAttr() &&
4858 parentWsLoopOp.getReductionModAttr().getValue() ==
4859 ReductionModifier::inscan)
4862 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4863 if (parentSimdOp.getReductionModAttr() &&
4864 parentSimdOp.getReductionModAttr().getValue() ==
4865 ReductionModifier::inscan)
4868 return emitError(
"SCAN directive needs to be enclosed within a parent "
4869 "worksharing loop construct or SIMD construct with INSCAN "
4870 "reduction modifier");
4875 std::optional<uint64_t> alignment) {
4876 if (alignment.has_value()) {
4877 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4879 <<
"ALIGN value : " << alignment.value() <<
" must be power of 2";
4884LogicalResult AllocateDirOp::verify() {
4892LogicalResult AllocSharedMemOp::verify() {
4900LogicalResult FreeSharedMemOp::verify() {
4908LogicalResult WorkdistributeOp::verify() {
4910 Region ®ion = getRegion();
4915 if (entryBlock.
empty())
4916 return emitOpError(
"region must contain a structured block");
4918 bool hasTerminator =
false;
4919 for (
Block &block : region) {
4920 if (isa<TerminatorOp>(block.back())) {
4921 if (hasTerminator) {
4922 return emitOpError(
"region must have exactly one terminator");
4924 hasTerminator =
true;
4927 if (!hasTerminator) {
4928 return emitOpError(
"region must be terminated with omp.terminator");
4932 if (isa<BarrierOp>(op)) {
4934 "explicit barriers are not allowed in workdistribute region");
4937 if (isa<ParallelOp>(op)) {
4939 "nested parallel constructs not allowed in workdistribute");
4941 if (isa<TeamsOp>(op)) {
4943 "nested teams constructs not allowed in workdistribute");
4947 if (walkResult.wasInterrupted())
4951 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4952 return emitOpError(
"workdistribute must be nested under teams");
4960LogicalResult DeclareSimdOp::verify() {
4963 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4965 return emitOpError() <<
"must be nested inside a function";
4967 if (getInbranch() && getNotinbranch())
4968 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
4978 const DeclareSimdOperands &clauses) {
4980 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4982 clauses.linearVars, clauses.linearStepVars,
4983 clauses.linearVarTypes, clauses.linearModifiers,
4984 clauses.notinbranch, clauses.simdlen,
4985 clauses.uniformVars);
5002 return mlir::failure();
5003 return mlir::success();
5010 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
5013 p << uniformVars[i] <<
" : " << uniformTypes[i];
5028 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5029 [&]() -> ParseResult {
return success(); })))
5063 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5064 if (parser.parseArgument(arg))
5068 if (succeeded(parser.parseOptionalColon())) {
5069 if (parser.parseType(arg.type))
5072 arg.type = parser.getBuilder().getIndexType();
5084 OpAsmParser::UnresolvedOperand lb, ub, st;
5085 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
5086 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
5087 parser.parseOperand(st))
5092 steps.push_back(st);
5100 if (ivArgs.size() != lbs.size())
5102 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
5105 for (
auto &arg : ivArgs) {
5106 lbTypes.push_back(arg.type);
5107 ubTypes.push_back(arg.type);
5108 stepTypes.push_back(arg.type);
5128 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
5131 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
5139LogicalResult IteratorOp::verify() {
5140 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
5142 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
5144 for (
auto [lb,
ub, step] : llvm::zip_equal(
5145 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5147 return emitOpError() <<
"loop step must not be zero";
5151 IntegerAttr stepAttr;
5157 const APInt &lbVal = lbAttr.getValue();
5158 const APInt &ubVal = ubAttr.getValue();
5159 const APInt &stepVal = stepAttr.getValue();
5160 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5161 return emitOpError() <<
"positive loop step requires lower bound to be "
5162 "less than or equal to upper bound";
5163 if (stepVal.isNegative() && lbVal.slt(ubVal))
5164 return emitOpError() <<
"negative loop step requires lower bound to be "
5165 "greater than or equal to upper bound";
5168 Block &
b = getRegion().front();
5169 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
5172 return emitOpError() <<
"region must be terminated by omp.yield";
5174 if (yield.getNumOperands() != 1)
5176 <<
"omp.yield in omp.iterator region must yield exactly one value";
5178 mlir::Type yieldedTy = yield.getOperand(0).getType();
5179 mlir::Type elemTy = iteratedTy.getElementType();
5181 if (yieldedTy != elemTy)
5182 return emitOpError() <<
"omp.iterated element type (" << elemTy
5183 <<
") does not match omp.yield operand type ("
5184 << yieldedTy <<
")";
5197 return emitOpError() <<
"expected symbol reference '" << getSymName()
5198 <<
"' to point to a global variable";
5200 if (isa<FunctionOpInterface>(symbol))
5201 return emitOpError() <<
"expected symbol reference '" << getSymName()
5202 <<
"' to point to a global variable, not a function";
5207#define GET_ATTRDEF_CLASSES
5208#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5210#define GET_OP_CLASSES
5211#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5213#define GET_TYPEDEF_CLASSES
5214#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.