28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
53 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
67struct MemRefPointerLikeModel
68 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
71 return llvm::cast<MemRefType>(pointer).getElementType();
75struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
102 bool isRegionArgOfOp;
112 assert(isRegionArgOfOp &&
"Must describe a region operand");
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp &&
"Must describe a region operand");
121 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
125 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 bool isLoopOp()
const {
129 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
141 void skipIf(
bool v =
true) { skip = skip || v; }
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp =
true;
164 for (
Block *
b : traversal) {
166 if (&op == o && !found) {
170 if (op.getNumRegions()) {
173 isOnlyContainerOp =
false;
175 if (found && !isOnlyContainerOp)
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp =
false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
// Append a component flagged as a region argument; it starts out unique with
// argument index 0 and `parent` as its owner operation.
190 Component &regionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp =
true;
192 regionArgOfOperation.isUnique =
true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
206 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
210 llvm_unreachable(
"Region not child of its parent operation");
212 regionArgOfOperation.isUnique =
false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
239 numSurroundingLoops = 0;
241 c.getLoopDepth() = numSurroundingLoops;
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
250 bool isLoopNest =
false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
258 }
else if (isLoopNest) {
260 c.skipIf(c.isUnique);
264 if (c.getLoopDepth() == 0)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
277 bool newRegion =
true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
286 if (!c.isRegionArgOfOp && c.getContainerOp())
292 llvm::raw_svector_ostream NameOS(Name);
293 for (
auto &c : llvm::reverse(components)) {
297 if (c.isRegionArgOfOp)
298 NameOS <<
"_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS <<
"_d" << c.getLoopDepth();
302 NameOS <<
"_s" << c.getOpPos();
305 return NameOS.str().str();
308void OpenMPDialect::initialize() {
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
336 mlir::LLVM::GlobalOp::attachInterface<
339 mlir::LLVM::LLVMFuncOp::attachInterface<
342 mlir::func::FuncOp::attachInterface<
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
387 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
389 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
390 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
398template <
typename ClauseAttr>
400 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
409 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
412template <
typename ClauseAttr>
414 p << stringifyEnum(attr.getValue());
439 std::optional<omp::LinearModifier> linearModifier;
441 linearModifier = omp::LinearModifier::val;
443 linearModifier = omp::LinearModifier::ref;
445 linearModifier = omp::LinearModifier::uval;
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.
parseLParen())
457 if (hasLinearModifierParens && parser.
parseRParen())
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
466 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
468 modifiers.push_back(UnitAttr::get(parser.
getContext()));
474 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
483 size_t linearVarsSize = linearVars.size();
484 for (
unsigned i = 0; i < linearVarsSize; ++i) {
488 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
491 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
493 p << linearVars[i] <<
" : " << linearTypes[i];
494 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
510 if (!linearModifiers)
512 if (linearModifiers->size() != linearVars.size())
514 <<
"expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (
Attribute attr : *linearModifiers) {
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
525 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
526 <<
"' may only be specified on a declare simd directive";
541 for (
const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->
emitOpError() <<
"nontemporal variable used more than once";
552 std::optional<ArrayAttr> alignments,
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
558 <<
"expected as many alignment values as aligned variables";
561 return op->
emitOpError() <<
"unexpected alignment values attribute";
567 for (
auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->
emitOpError() <<
"aligned variable used more than once";
575 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->
emitOpError() <<
"alignment should be greater than 0";
580 return op->
emitOpError() <<
"expected integer alignment";
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
614 std::optional<ArrayAttr> alignments) {
615 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
618 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
619 p <<
" -> " << (*alignments)[i];
630 if (modifiers.size() > 2)
632 for (
const auto &mod : modifiers) {
635 auto symbol = symbolizeScheduleModifier(mod);
638 <<
" unknown modifier type: " << mod;
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
648 }
else if (modifiers.size() == 2) {
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
654 <<
" incorrect modifier order";
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
691 chunkSize = std::nullopt;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
706 modifiers.push_back(mod);
712 if (!modifiers.empty()) {
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
718 return parser.
emitError(loc,
"invalid schedule modifier");
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd,
Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
738 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
740 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
772 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
779 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
781 p << stringifyClauseOrderKind(order.getValue());
784template <
typename ClauseTypeAttr,
typename ClauseType>
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
799 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
809 <<
"expected " << clauseName <<
" operand";
812 if (operand.has_value()) {
820template <
typename ClauseTypeAttr,
typename ClauseType>
823 ClauseTypeAttr prescriptiveness,
Value operand,
825 StringRef (*stringifyClauseType)(ClauseType)) {
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
831 p << operand <<
": " << operandType;
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType,
"grainsize");
849 ClauseGrainsizeTypeAttr grainsizeMod,
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
871 ClauseNumTasksTypeAttr numTasksMod,
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
883 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
884 SmallVectorImpl<Type> &types;
885 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
886 SmallVectorImpl<Type> &types)
887 : vars(vars), types(types) {}
889struct PrivateParseArgs {
890 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
891 llvm::SmallVectorImpl<Type> &types;
893 UnitAttr &needsBarrier;
895 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
896 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
897 UnitAttr &needsBarrier,
899 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
900 mapIndices(mapIndices) {}
903struct ReductionParseArgs {
904 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
905 SmallVectorImpl<Type> &types;
908 ReductionModifierAttr *modifier;
909 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
911 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
912 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
915struct AllRegionParseArgs {
916 std::optional<MapParseArgs> hasDeviceAddrArgs;
917 std::optional<MapParseArgs> hostEvalArgs;
918 std::optional<ReductionParseArgs> inReductionArgs;
919 std::optional<MapParseArgs> mapArgs;
920 std::optional<PrivateParseArgs> privateArgs;
921 std::optional<ReductionParseArgs> reductionArgs;
922 std::optional<ReductionParseArgs> taskReductionArgs;
923 std::optional<MapParseArgs> useDeviceAddrArgs;
924 std::optional<MapParseArgs> useDevicePtrArgs;
929 return "private_barrier";
939 ReductionModifierAttr *modifier =
nullptr,
940 UnitAttr *needsBarrier =
nullptr) {
944 unsigned regionArgOffset = regionPrivateArgs.size();
954 std::optional<ReductionModifier> enumValue =
955 symbolizeReductionModifier(enumStr);
956 if (!enumValue.has_value())
958 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
965 isByRefVec.push_back(
966 parser.parseOptionalKeyword(
"byref").succeeded());
968 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
971 if (parser.parseOperand(operands.emplace_back()) ||
972 parser.parseArrow() ||
973 parser.parseArgument(regionPrivateArgs.emplace_back()))
977 if (parser.parseOptionalLSquare().succeeded()) {
978 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
979 parser.parseInteger(mapIndicesVec.emplace_back()) ||
980 parser.parseRSquare())
983 mapIndicesVec.push_back(-1);
995 if (parser.parseType(types.emplace_back()))
1002 if (operands.size() != types.size())
1011 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1014 auto *argsBegin = regionPrivateArgs.begin();
1016 argsBegin + regionArgOffset + types.size());
1017 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1023 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1026 if (!mapIndicesVec.empty())
1039 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1054 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1060 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1061 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1062 nullptr, &privateArgs->needsBarrier)))
1071 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1076 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1077 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1078 reductionArgs->modifier)))
1085 AllRegionParseArgs args) {
1089 args.hasDeviceAddrArgs)))
1091 <<
"invalid `has_device_addr` format";
1094 args.hostEvalArgs)))
1096 <<
"invalid `host_eval` format";
1099 args.inReductionArgs)))
1101 <<
"invalid `in_reduction` format";
1106 <<
"invalid `map_entries` format";
1111 <<
"invalid `private` format";
1114 args.reductionArgs)))
1116 <<
"invalid `reduction` format";
1119 args.taskReductionArgs)))
1121 <<
"invalid `task_reduction` format";
1124 args.useDeviceAddrArgs)))
1126 <<
"invalid `use_device_addr` format";
1129 args.useDevicePtrArgs)))
// The diagnostic names the clause that failed to parse (use_device_ptr here,
// not use_device_addr, which is handled by the preceding check).
1131 <<
"invalid `use_device_ptr` format";
1133 return parser.
parseRegion(region, entryBlockArgs);
1152 AllRegionParseArgs args;
1153 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1154 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1155 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1156 inReductionByref, inReductionSyms);
1157 args.mapArgs.emplace(mapVars, mapTypes);
1158 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1159 privateNeedsBarrier, &privateMaps);
1170 UnitAttr &privateNeedsBarrier) {
1171 AllRegionParseArgs args;
1172 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1173 inReductionByref, inReductionSyms);
1174 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1175 privateNeedsBarrier);
1186 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1190 AllRegionParseArgs args;
1191 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1192 inReductionByref, inReductionSyms);
1193 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1194 privateNeedsBarrier);
1195 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1196 reductionSyms, &reductionMod);
1204 UnitAttr &privateNeedsBarrier) {
1205 AllRegionParseArgs args;
1206 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1207 privateNeedsBarrier);
1215 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1219 AllRegionParseArgs args;
1220 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1221 privateNeedsBarrier);
1222 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1223 reductionSyms, &reductionMod);
1232 AllRegionParseArgs args;
1233 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1234 taskReductionByref, taskReductionSyms);
1244 AllRegionParseArgs args;
1245 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1246 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1255struct MapPrintArgs {
1260struct PrivatePrintArgs {
1264 UnitAttr needsBarrier;
1268 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1269 mapIndices(mapIndices) {}
1271struct ReductionPrintArgs {
1276 ReductionModifierAttr modifier;
1278 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1279 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1281struct AllRegionPrintArgs {
1282 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1283 std::optional<MapPrintArgs> hostEvalArgs;
1284 std::optional<ReductionPrintArgs> inReductionArgs;
1285 std::optional<MapPrintArgs> mapArgs;
1286 std::optional<PrivatePrintArgs> privateArgs;
1287 std::optional<ReductionPrintArgs> reductionArgs;
1288 std::optional<ReductionPrintArgs> taskReductionArgs;
1289 std::optional<MapPrintArgs> useDeviceAddrArgs;
1290 std::optional<MapPrintArgs> useDevicePtrArgs;
1299 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1300 if (argsSubrange.empty())
1303 p << clauseName <<
"(";
1306 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1310 symbols = ArrayAttr::get(ctx, values);
1323 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1324 mapIndices.asArrayRef(),
1325 byref.asArrayRef()),
1327 auto [op, arg, sym, map, isByRef] = t;
1333 p << op <<
" -> " << arg;
1336 p <<
" [map_idx=" << map <<
"]";
1339 llvm::interleaveComma(types, p);
1347 StringRef clauseName,
ValueRange argsSubrange,
1348 std::optional<MapPrintArgs> mapArgs) {
1355 StringRef clauseName,
ValueRange argsSubrange,
1356 std::optional<PrivatePrintArgs> privateArgs) {
1359 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1360 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1361 nullptr, privateArgs->needsBarrier);
1367 std::optional<ReductionPrintArgs> reductionArgs) {
1370 reductionArgs->vars, reductionArgs->types,
1371 reductionArgs->syms,
nullptr,
1372 reductionArgs->byref, reductionArgs->modifier);
1376 const AllRegionPrintArgs &args) {
1377 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1381 iface.getHasDeviceAddrBlockArgs(),
1382 args.hasDeviceAddrArgs);
1386 args.inReductionArgs);
1392 args.reductionArgs);
1394 iface.getTaskReductionBlockArgs(),
1395 args.taskReductionArgs);
1397 iface.getUseDeviceAddrBlockArgs(),
1398 args.useDeviceAddrArgs);
1400 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1416 AllRegionPrintArgs args;
1417 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1418 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1419 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1420 inReductionByref, inReductionSyms);
1421 args.mapArgs.emplace(mapVars, mapTypes);
1422 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1423 privateNeedsBarrier, privateMaps);
1431 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1432 AllRegionPrintArgs args;
1433 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1434 inReductionByref, inReductionSyms);
1435 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1436 privateNeedsBarrier,
1445 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1446 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1449 AllRegionPrintArgs args;
1450 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1451 inReductionByref, inReductionSyms);
1452 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1453 privateNeedsBarrier,
1455 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1456 reductionSyms, reductionMod);
1463 UnitAttr privateNeedsBarrier) {
1464 AllRegionPrintArgs args;
1465 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1466 privateNeedsBarrier,
1474 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1477 AllRegionPrintArgs args;
1478 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1479 privateNeedsBarrier,
1481 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1482 reductionSyms, reductionMod);
1492 AllRegionPrintArgs args;
1493 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1494 taskReductionByref, taskReductionSyms);
1504 AllRegionPrintArgs args;
1505 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1506 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1510template <
typename ParsePrefixFn>
1519 if (failed(parsePrefix()))
1527 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1528 iteratedVars.push_back(v);
1529 iteratedTypes.push_back(ty);
1531 plainVars.push_back(v);
1532 plainTypes.push_back(ty);
// PrintPrefixFn: callable type of the two prefix printers declared below
// (printPrefixForPlain / printPrefixForIterated).
1538template <
typename PrintPrefixFn>
1542 PrintPrefixFn &&printPrefixForPlain,
1543 PrintPrefixFn &&printPrefixForIterated) {
1550 p << v <<
" : " << t;
1554 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1555 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1556 for (
unsigned i = 0; i < plainVars.size(); ++i)
1557 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1565 if (!reductionVars.empty()) {
1566 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1568 <<
"expected as many reduction symbol references "
1569 "as reduction variables";
1570 if (reductionByref && reductionByref->size() != reductionVars.size())
1571 return op->
emitError() <<
"expected as many reduction variable by "
1572 "reference attributes as reduction variables";
1575 return op->
emitOpError() <<
"unexpected reduction symbol references";
1582 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1583 Value accum = std::get<0>(args);
1585 if (!accumulators.insert(accum).second)
1586 return op->
emitOpError() <<
"accumulator variable used more than once";
1589 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1593 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1594 <<
" to point to a reduction declaration";
1596 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1598 <<
"expected accumulator (" << varType
1599 <<
") to be the same type as reduction declaration ("
1600 << decl.getAccumulatorType() <<
")";
1619 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1620 parser.parseArrow() ||
1621 parser.parseAttribute(symsVec.emplace_back()) ||
1622 parser.parseColonType(copyprivateTypes.emplace_back()))
1628 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1636 std::optional<ArrayAttr> copyprivateSyms) {
1637 if (!copyprivateSyms.has_value())
1639 llvm::interleaveComma(
1640 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1641 [&](
const auto &args) {
1642 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1643 << std::get<2>(args);
1650 std::optional<ArrayAttr> copyprivateSyms) {
1651 size_t copyprivateSymsSize =
1652 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1653 if (copyprivateSymsSize != copyprivateVars.size())
1654 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1655 << copyprivateVars.size()
1656 <<
") and functions (= " << copyprivateSymsSize
1657 <<
"), both must be equal";
1658 if (!copyprivateSyms.has_value())
1661 for (
auto copyprivateVarAndSym :
1662 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1664 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1665 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1667 if (mlir::func::FuncOp mlirFuncOp =
1670 funcOp = mlirFuncOp;
1671 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1674 funcOp = llvmFuncOp;
1676 auto getNumArguments = [&] {
1677 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1680 auto getArgumentType = [&](
unsigned i) {
1681 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1686 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1687 <<
" to point to a copy function";
1689 if (getNumArguments() != 2)
1691 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1693 Type argTy = getArgumentType(0);
1694 if (argTy != getArgumentType(1))
1695 return op->
emitOpError() <<
"expected copy function " << symbolRef
1696 <<
" arguments to have the same type";
1698 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1699 if (argTy != varType)
1701 <<
"expected copy function arguments' type (" << argTy
1702 <<
") to be the same as copyprivate variable's type (" << varType
1727 OpAsmParser::UnresolvedOperand operand;
1729 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1730 parser.parseOperand(operand) || parser.parseColonType(ty))
1732 std::optional<ClauseTaskDepend> keywordDepend =
1733 symbolizeClauseTaskDepend(keyword);
1737 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1738 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1739 iteratedVars.push_back(operand);
1740 iteratedTypes.push_back(ty);
1741 iterKindsVec.push_back(kindAttr);
1743 dependVars.push_back(operand);
1744 dependTypes.push_back(ty);
1745 kindsVec.push_back(kindAttr);
1751 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1753 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1760 std::optional<ArrayAttr> dependKinds,
1763 std::optional<ArrayAttr> iteratedKinds) {
1766 std::optional<ArrayAttr> kinds) {
1767 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1770 p << stringifyClauseTaskDepend(
1771 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1773 <<
" -> " << vars[i] <<
" : " << types[i];
1777 printEntries(dependVars, dependTypes, dependKinds);
1778 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1783 std::optional<ArrayAttr> dependKinds,
1785 std::optional<ArrayAttr> iteratedKinds,
1787 if (!dependVars.empty()) {
1788 if (!dependKinds || dependKinds->size() != dependVars.size())
1789 return op->
emitOpError() <<
"expected as many depend values"
1790 " as depend variables";
1792 if (dependKinds && !dependKinds->empty())
1793 return op->
emitOpError() <<
"unexpected depend values";
1796 if (!iteratedVars.empty()) {
1797 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1798 return op->
emitOpError() <<
"expected as many depend iterated values"
1799 " as depend iterated variables";
1801 if (iteratedKinds && !iteratedKinds->empty())
1802 return op->
emitOpError() <<
"unexpected depend iterated values";
1817 IntegerAttr &hintAttr) {
1818 StringRef hintKeyword;
1824 auto parseKeyword = [&]() -> ParseResult {
1827 if (hintKeyword ==
"uncontended")
1829 else if (hintKeyword ==
"contended")
1831 else if (hintKeyword ==
"nonspeculative")
1833 else if (hintKeyword ==
"speculative")
1837 << hintKeyword <<
" is not a valid hint";
1848 IntegerAttr hintAttr) {
1849 int64_t hint = hintAttr.getInt();
1857 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1859 bool uncontended = bitn(hint, 0);
1860 bool contended = bitn(hint, 1);
1861 bool nonspeculative = bitn(hint, 2);
1862 bool speculative = bitn(hint, 3);
1866 hints.push_back(
"uncontended");
1868 hints.push_back(
"contended");
1870 hints.push_back(
"nonspeculative");
1872 hints.push_back(
"speculative");
1874 llvm::interleaveComma(hints, p);
1881 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1883 bool uncontended = bitn(hint, 0);
1884 bool contended = bitn(hint, 1);
1885 bool nonspeculative = bitn(hint, 2);
1886 bool speculative = bitn(hint, 3);
1888 if (uncontended && contended)
1889 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1890 "omp_sync_hint_contended cannot be combined";
1891 if (nonspeculative && speculative)
1892 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1893 "omp_sync_hint_speculative cannot be combined.";
1904 return (value & flag) == flag;
1912static ParseResult parseMapClause(
OpAsmParser &parser,
1913 ClauseMapFlagsAttr &mapType) {
1914 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1917 auto parseTypeAndMod = [&]() -> ParseResult {
1918 StringRef mapTypeMod;
1922 if (mapTypeMod ==
"always")
1923 mapTypeBits |= ClauseMapFlags::always;
1925 if (mapTypeMod ==
"implicit")
1926 mapTypeBits |= ClauseMapFlags::implicit;
1928 if (mapTypeMod ==
"ompx_hold")
1929 mapTypeBits |= ClauseMapFlags::ompx_hold;
1931 if (mapTypeMod ==
"close")
1932 mapTypeBits |= ClauseMapFlags::close;
1934 if (mapTypeMod ==
"present")
1935 mapTypeBits |= ClauseMapFlags::present;
1937 if (mapTypeMod ==
"to")
1938 mapTypeBits |= ClauseMapFlags::to;
1940 if (mapTypeMod ==
"from")
1941 mapTypeBits |= ClauseMapFlags::from;
1943 if (mapTypeMod ==
"tofrom")
1944 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1946 if (mapTypeMod ==
"delete")
1947 mapTypeBits |= ClauseMapFlags::del;
1949 if (mapTypeMod ==
"storage")
1950 mapTypeBits |= ClauseMapFlags::storage;
1952 if (mapTypeMod ==
"return_param")
1953 mapTypeBits |= ClauseMapFlags::return_param;
1955 if (mapTypeMod ==
"private")
1956 mapTypeBits |= ClauseMapFlags::priv;
1958 if (mapTypeMod ==
"literal")
1959 mapTypeBits |= ClauseMapFlags::literal;
1961 if (mapTypeMod ==
"attach")
1962 mapTypeBits |= ClauseMapFlags::attach;
1964 if (mapTypeMod ==
"attach_always")
1965 mapTypeBits |= ClauseMapFlags::attach_always;
1967 if (mapTypeMod ==
"attach_never")
1968 mapTypeBits |= ClauseMapFlags::attach_never;
1970 if (mapTypeMod ==
"attach_auto")
1971 mapTypeBits |= ClauseMapFlags::attach_auto;
1973 if (mapTypeMod ==
"ref_ptr")
1974 mapTypeBits |= ClauseMapFlags::ref_ptr;
1976 if (mapTypeMod ==
"ref_ptee")
1977 mapTypeBits |= ClauseMapFlags::ref_ptee;
1979 if (mapTypeMod ==
"ref_ptr_ptee")
1980 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1982 if (mapTypeMod ==
"is_device_ptr")
1983 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2000 ClauseMapFlagsAttr mapType) {
2002 ClauseMapFlags mapFlags = mapType.getValue();
2007 mapTypeStrs.push_back(
"always");
2009 mapTypeStrs.push_back(
"implicit");
2011 mapTypeStrs.push_back(
"ompx_hold");
2013 mapTypeStrs.push_back(
"close");
2015 mapTypeStrs.push_back(
"present");
2024 mapTypeStrs.push_back(
"tofrom");
2026 mapTypeStrs.push_back(
"from");
2028 mapTypeStrs.push_back(
"to");
2031 mapTypeStrs.push_back(
"delete");
2033 mapTypeStrs.push_back(
"return_param");
2035 mapTypeStrs.push_back(
"storage");
2037 mapTypeStrs.push_back(
"private");
2039 mapTypeStrs.push_back(
"literal");
2041 mapTypeStrs.push_back(
"attach");
2043 mapTypeStrs.push_back(
"attach_always");
2045 mapTypeStrs.push_back(
"attach_never");
2047 mapTypeStrs.push_back(
"attach_auto");
2049 mapTypeStrs.push_back(
"ref_ptr");
2051 mapTypeStrs.push_back(
"ref_ptee");
2053 mapTypeStrs.push_back(
"ref_ptr_ptee");
2055 mapTypeStrs.push_back(
"is_device_ptr");
2056 if (mapFlags == ClauseMapFlags::none)
2057 mapTypeStrs.push_back(
"none");
2059 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2060 p << mapTypeStrs[i];
2061 if (i + 1 < mapTypeStrs.size()) {
2067static ParseResult parseMembersIndex(
OpAsmParser &parser,
2071 auto parseIndices = [&]() -> ParseResult {
2076 APInt(64, value,
false)));
2090 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2094 if (!memberIdxs.empty())
2095 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2105 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2107 auto memberIdx = cast<ArrayAttr>(v);
2108 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2109 p << cast<IntegerAttr>(v2).getInt();
2116 VariableCaptureKindAttr mapCaptureType) {
2117 std::string typeCapStr;
2118 llvm::raw_string_ostream typeCap(typeCapStr);
2119 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2121 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2122 typeCap <<
"ByCopy";
2123 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2124 typeCap <<
"VLAType";
2125 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2131 VariableCaptureKindAttr &mapCaptureType) {
2132 StringRef mapCaptureKey;
2136 if (mapCaptureKey ==
"This")
2137 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2138 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2139 if (mapCaptureKey ==
"ByRef")
2140 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2141 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2142 if (mapCaptureKey ==
"ByCopy")
2143 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2144 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2145 if (mapCaptureKey ==
"VLAType")
2146 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2147 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2156 for (
auto mapOp : mapVars) {
2157 if (!mapOp.getDefiningOp())
2160 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2161 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2164 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2167 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2168 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2169 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2171 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2173 "to, from, tofrom and alloc map types are permitted");
2175 if (isa<TargetEnterDataOp>(op) && (from || del))
2176 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2178 if (isa<TargetExitDataOp>(op) && to)
2180 "from, release and delete map types are permitted");
2182 if (isa<TargetUpdateOp>(op)) {
2185 "at least one of to or from map types must be "
2186 "specified, other map types are not permitted");
2191 "at least one of to or from map types must be "
2192 "specified, other map types are not permitted");
2195 auto updateVar = mapInfoOp.getVarPtr();
2197 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2198 (from && updateToVars.contains(updateVar))) {
2201 "either to or from map types can be specified, not both");
2204 if (always || close || implicit) {
2207 "present, mapper and iterator map type modifiers are permitted");
2210 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2212 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2214 "map argument is not a map entry operation");
2222 std::optional<DenseI64ArrayAttr> privateMapIndices =
2223 targetOp.getPrivateMapsAttr();
2226 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2231 if (privateMapIndices.value().size() !=
2232 static_cast<int64_t>(privateVars.size()))
2233 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2234 "`private_maps` attribute mismatch");
2244 StringRef clauseName,
2246 for (
Value var : vars)
2247 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2249 <<
"'" << clauseName
2250 <<
"' arguments must be defined by 'omp.map.info' ops";
2254LogicalResult MapInfoOp::verify() {
2255 if (getMapperId() &&
2257 *
this, getMapperIdAttr())) {
2272 const TargetDataOperands &clauses) {
2273 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2274 clauses.mapVars, clauses.useDeviceAddrVars,
2275 clauses.useDevicePtrVars);
// Verifier for TargetDataOp.
// Visible check: the op is rejected when its map, use_device_ptr and
// use_device_addr operand lists are all empty.
// The two trailing fragments pass the use_device_ptr and use_device_addr
// operand lists to further checks whose callee is not part of this excerpt.
2278LogicalResult TargetDataOp::verify() {
2279 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2280 getUseDeviceAddrVars().empty()) {
2281 return ::emitError(this->getLoc(),
2282 "At least one of map, use_device_ptr_vars, or "
2283 "use_device_addr_vars operand must be present");
2287 getUseDevicePtrVars())))
2291 getUseDeviceAddrVars())))
2301void TargetEnterDataOp::build(
2305 TargetEnterDataOp::build(
2307 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2308 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2312LogicalResult TargetEnterDataOp::verify() {
2313 LogicalResult verifyDependVars =
2315 getDependIteratedKinds(), getDependIterated());
2316 return failed(verifyDependVars) ? verifyDependVars
2327 TargetExitDataOp::build(
2329 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2330 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2334LogicalResult TargetExitDataOp::verify() {
2335 LogicalResult verifyDependVars =
2337 getDependIteratedKinds(), getDependIterated());
2338 return failed(verifyDependVars) ? verifyDependVars
2349 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2352 clauses.dependIterated, clauses.device, clauses.ifExpr,
2353 clauses.mapVars, clauses.nowait);
2356LogicalResult TargetUpdateOp::verify() {
2357 LogicalResult verifyDependVars =
2359 getDependIteratedKinds(), getDependIterated());
2360 return failed(verifyDependVars) ? verifyDependVars
2369 const TargetOperands &clauses) {
2374 builder, state, {}, {}, clauses.bare,
2375 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2376 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2377 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2380 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2381 clauses.nowait, clauses.privateVars,
2382 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2383 clauses.threadLimitVars,
2387LogicalResult TargetOp::verify() {
2389 getDependIteratedKinds(),
2390 getDependIterated())))
2394 getHasDeviceAddrVars())))
// Region verifier for TargetOp. From the visible code it:
//  - rejects the op when getOps<TeamsOp>() yields more than one omp.teams;
//  - computes the innermost captured OpenMP op and its kernel execution flags;
//  - iterates the host_eval block arguments and reports an error for uses that
//    do not match one of the patterns spelled out in the diagnostics below
//    (omp.teams, omp.parallel, omp.loop_nest).
// NOTE(review): `user` is declared on a line outside this excerpt; presumably
// it ranges over the users of `hostEvalArg` - confirm against the full file.
2403LogicalResult TargetOp::verifyRegions() {
2404 auto teamsOps = getOps<TeamsOp>();
2405 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2406 return emitError(
"target containing multiple 'omp.teams' nested ops");
// The captured op and its flags drive the omp.parallel / omp.loop_nest cases.
2409 Operation *capturedOp = getInnermostCapturedOmpOp();
2410 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2411 for (
Value hostEvalArg :
2412 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
// Case 1: user is omp.teams. The condition tests the num_teams lower bound,
// the num_teams upper bounds and the thread_limit operands.
2414 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2416 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2417 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2418 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2421 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2422 "and 'thread_limit' in 'omp.teams'";
// Case 2: user is omp.parallel. The condition requires the spmd flag, that
// the parallel op is an ancestor of the captured op, and that the argument is
// one of its num_threads operands.
2424 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2425 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2426 parallelOp->isAncestor(capturedOp) &&
2427 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2431 <<
"host_eval argument only legal as 'num_threads' in "
2432 "'omp.parallel' when representing target SPMD";
// Case 3: user is omp.loop_nest. The condition requires the trip_count flag,
// that the loop nest is the captured op itself, and that the argument is a
// lower bound, upper bound or step.
2434 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2435 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2436 loopNestOp.getOperation() == capturedOp &&
2437 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2439 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2442 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2443 "and steps in 'omp.loop_nest' when trip count "
2444 "must be evaluated in the host";
// Fallback diagnostic naming the offending user operation.
2447 return emitOpError() <<
"host_eval argument illegal use in '"
2448 << user->getName() <<
"' operation";
2457 assert(rootOp &&
"expected valid operation");
2474 bool isOmpDialect = op->
getDialect() == ompDialect;
2476 if (!isOmpDialect || !hasRegions)
2483 if (checkSingleMandatoryExec) {
2488 if (successor->isReachable(parentBlock))
2491 for (
Block &block : *parentRegion)
2493 !domInfo.
dominates(parentBlock, &block))
2500 if (&sibling != op && !siblingAllowedFn(&sibling))
2513Operation *TargetOp::getInnermostCapturedOmpOp() {
2514 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2526 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2529 memOp.getEffects(effects);
2530 return !llvm::any_of(
2532 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2533 isa<SideEffects::AutomaticAllocationScopeResource>(
2543 WsloopOp *wsLoopOp) {
2545 if (!teamsOp.getNumTeamsUpperVars().empty())
2549 if (teamsOp.getNumReductionVars())
2551 if (wsLoopOp->getNumReductionVars())
2555 OffloadModuleInterface offloadMod =
2559 auto ompFlags = offloadMod.getFlags();
2562 return ompFlags.getAssumeTeamsOversubscription() &&
2563 ompFlags.getAssumeThreadsOversubscription();
// Classifies the kernel represented by `capturedOp` and returns the matching
// TargetRegionFlags combination (generic, spmd, trip_count and unions of
// them, as seen in the return statements below).
//
// Visible decision points:
//  - a null `capturedOp`, or one that is not an omp.loop_nest, yields `generic`;
//  - the wrappers around the loop nest are gathered; if the first gathered
//    wrapper is an omp.simd it is skipped;
//  - only one or two remaining wrappers are handled, anything else is
//    `generic`;
//  - the two-wrapper case, the omp.distribute / omp.loop case and the
//    omp.wsloop case each inspect parent omp.parallel / omp.teams ops and
//    compare their parent with the target op.
// NOTE(review): `targetOp`, `loopWrappers`, `parallelOp`, `teamsOp` (second
// branch), `sibling` and `nestedCapture` are defined on lines outside this
// excerpt, so the exact conditions of several branches cannot be confirmed
// from here.
2566TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2571 assert((!capturedOp ||
2572 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2573 "unexpected captured op");
// Anything other than a loop nest is treated as a generic kernel.
2576 if (!isa_and_present<LoopNestOp>(capturedOp))
2577 return TargetRegionFlags::generic;
2581 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2582 assert(!loopWrappers.empty());
// Step over a first omp.simd wrapper before counting the remaining wrappers.
2584 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2585 if (isa<SimdOp>(innermostWrapper))
2586 innermostWrapper = std::next(innermostWrapper);
2588 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2589 if (numWrappers != 1 && numWrappers != 2)
2590 return TargetRegionFlags::generic;
// Two wrappers: the visible checks expect an omp.wsloop followed by an
// omp.distribute, then look at the enclosing omp.parallel and omp.teams.
2593 if (numWrappers == 2) {
2594 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2596 return TargetRegionFlags::generic;
2598 innermostWrapper = std::next(innermostWrapper);
2599 if (!isa<DistributeOp>(innermostWrapper))
2600 return TargetRegionFlags::generic;
2603 if (!isa_and_present<ParallelOp>(parallelOp))
2604 return TargetRegionFlags::generic;
2606 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2608 return TargetRegionFlags::generic;
2610 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2611 TargetRegionFlags
result =
2612 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2619 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2621 if (!isa_and_present<TeamsOp>(teamsOp))
2622 return TargetRegionFlags::generic;
2624 if (teamsOp->
getParentOp() != targetOp.getOperation())
2625 return TargetRegionFlags::generic;
2627 if (isa<LoopOp>(innermostWrapper))
2628 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2638 Dialect *ompDialect = targetOp->getDialect();
2642 return sibling && (ompDialect != sibling->
getDialect() ||
2646 TargetRegionFlags
result =
2647 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2652 while (nestedCapture->
getParentOp() != capturedOp)
2655 return isa<ParallelOp>(nestedCapture) ?
result | TargetRegionFlags::spmd
2659 else if (isa<WsloopOp>(innermostWrapper)) {
2661 if (!isa_and_present<ParallelOp>(parallelOp))
2662 return TargetRegionFlags::generic;
2664 if (parallelOp->
getParentOp() == targetOp.getOperation())
2665 return TargetRegionFlags::spmd;
2668 return TargetRegionFlags::generic;
2677 ParallelOp::build(builder, state,
ValueRange(),
2689 const ParallelOperands &clauses) {
2691 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2692 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2694 clauses.privateNeedsBarrier, clauses.procBindKind,
2695 clauses.reductionMod, clauses.reductionVars,
2700template <
typename OpType>
2702 auto privateVars = op.getPrivateVars();
2703 auto privateSyms = op.getPrivateSymsAttr();
2705 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2708 auto numPrivateVars = privateVars.size();
2709 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2711 if (numPrivateVars != numPrivateSyms)
2712 return op.emitError() <<
"inconsistent number of private variables and "
2713 "privatizer op symbols, private vars: "
2715 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2717 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2718 Type varType = std::get<0>(privateVarInfo).getType();
2719 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2720 PrivateClauseOp privatizerOp =
2723 if (privatizerOp ==
nullptr)
2724 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2725 << privateSym <<
"'";
2727 Type privatizerType = privatizerOp.getArgType();
2729 if (privatizerType && (varType != privatizerType))
2730 return op.emitError()
2731 <<
"type mismatch between a "
2732 << (privatizerOp.getDataSharingType() ==
2733 DataSharingClauseType::Private
2736 <<
" variable and its privatizer op, var type: " << varType
2737 <<
" vs. privatizer op type: " << privatizerType;
2743LogicalResult ParallelOp::verify() {
2744 if (getAllocateVars().size() != getAllocatorVars().size())
2746 "expected equal sizes for allocate and allocator variables");
2752 getReductionByref());
// Region verifier for ParallelOp, centred on nested omp.distribute ops:
//  - more than one omp.distribute returned by getOps<DistributeOp>() is an
//    error;
//  - with exactly one, the diagnostics show that the 'omp.composite'
//    attribute is required and that other OpenMP-dialect child ops are
//    reported as unexpected (ops equal to the distribute op or belonging to
//    another dialect are excluded by the visible condition);
//  - with none, carrying the 'omp.composite' attribute is an error.
// NOTE(review): `childOp` is declared on a line outside this excerpt.
2755LogicalResult ParallelOp::verifyRegions() {
2756 auto distChildOps = getOps<DistributeOp>();
2757 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2758 if (numDistChildOps > 1)
2760 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2762 if (numDistChildOps == 1) {
2765 <<
"'omp.composite' attribute missing from composite operation";
2767 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2768 Operation &distributeOp = **distChildOps.begin();
2770 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2774 return emitError() <<
"unexpected OpenMP operation inside of composite "
2776 << childOp.getName();
2778 }
else if (isComposite()) {
2780 <<
"'omp.composite' attribute present in non-composite operation";
2797 const TeamsOperands &clauses) {
2801 builder, state, clauses.allocateVars, clauses.allocatorVars,
2802 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2804 nullptr, clauses.reductionMod,
2805 clauses.reductionVars,
2807 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2814 if (numTeamsLower) {
2815 if (numTeamsUpperVars.size() != 1)
2817 "expected exactly one num_teams upper bound when lower bound is "
2821 "expected num_teams upper bound and lower bound to be "
2828LogicalResult TeamsOp::verify() {
2837 return emitError(
"expected to be nested inside of omp.target or not nested "
2838 "in any OpenMP dialect operations");
2842 this->getNumTeamsUpperVars())))
2846 if (getAllocateVars().size() != getAllocatorVars().size())
2848 "expected equal sizes for allocate and allocator variables");
2851 getReductionByref());
2859 return getParentOp().getPrivateVars();
2863 return getParentOp().getReductionVars();
2871 const SectionsOperands &clauses) {
2874 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2877 clauses.reductionMod, clauses.reductionVars,
2882LogicalResult SectionsOp::verify() {
2883 if (getAllocateVars().size() != getAllocatorVars().size())
2885 "expected equal sizes for allocate and allocator variables");
2888 getReductionByref());
// Region verifier for SectionsOp: walks the operations of the first block of
// the region and flags any that is neither a SectionOp nor a TerminatorOp.
// NOTE(review): `*getRegion().begin()` dereferences the first block without a
// visible emptiness check; confirm the region is guaranteed non-empty here.
2891LogicalResult SectionsOp::verifyRegions() {
2892 for (
auto &inst : *getRegion().begin()) {
2893 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2895 <<
"expected omp.section op or terminator op inside region";
2907 const SingleOperands &clauses) {
2910 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2911 clauses.copyprivateVars,
2912 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2917LogicalResult SingleOp::verify() {
2919 if (getAllocateVars().size() != getAllocatorVars().size())
2921 "expected equal sizes for allocate and allocator variables");
2924 getCopyprivateSyms());
2932 const WorkshareOperands &clauses) {
2933 WorkshareOp::build(builder, state, clauses.nowait);
// Verifier for WorkshareLoopWrapperOp: reports an error unless some ancestor
// (found through getParentOfType) is a WorkshareOp.
2940LogicalResult WorkshareLoopWrapperOp::verify() {
2941 if (!(*this)->getParentOfType<WorkshareOp>())
2942 return emitOpError() <<
"must be nested in an omp.workshare";
// Region verifier for WorkshareLoopWrapperOp: the op must be a standalone
// loop wrapper. The visible half of the condition rejects a parent that is
// itself a LoopWrapperInterface; the second operand of the `||` is on a line
// outside this excerpt.
2946LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2947 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2949 return emitOpError() <<
"expected to be a standalone loop wrapper";
// Shared verifier for ops implementing LoopWrapperInterface. According to the
// visible diagnostics it checks, in order, that the op:
//  - has the `NoTerminator` and `SingleBlock` traits;
//  - contains exactly one region;
//  - contains exactly one nested op in that region;
//  - nests either another loop wrapper or an omp.loop_nest.
// NOTE(review): the conditions guarding the first two diagnostics, and the
// definitions of `region` and `firstOp`, are on lines outside this excerpt.
2958LogicalResult LoopWrapperInterface::verifyImpl() {
2962 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2963 "and `SingleBlock` traits";
2966 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2969 if (range_size(region.
getOps()) != 1)
2971 <<
"loop wrapper does not contain exactly one nested op";
2974 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2975 return emitOpError() <<
"nested in loop wrapper is not another loop "
2976 "wrapper or `omp.loop_nest`";
2986 const LoopOperands &clauses) {
2989 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2991 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2992 clauses.reductionMod, clauses.reductionVars,
2997LogicalResult LoopOp::verify() {
2999 getReductionByref());
// Region verifier for LoopOp: the op must be a standalone loop wrapper. The
// visible half of the condition rejects a parent that is itself a
// LoopWrapperInterface; the second operand of the `||` is on a line outside
// this excerpt.
3002LogicalResult LoopOp::verifyRegions() {
3003 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3005 return emitOpError() <<
"expected to be a standalone loop wrapper";
3016 build(builder, state, {}, {},
3019 false,
nullptr,
nullptr,
3020 nullptr, {},
nullptr,
3031 const WsloopOperands &clauses) {
3036 {}, {}, clauses.linearVars,
3037 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3038 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3039 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3040 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3042 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3043 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
// Verifier for WsloopOp.
// Visible check: when linear variables are present, the linear variable type
// list must have the same number of entries as the linear variable list.
// The trailing fragment passes getReductionByref() to a check whose callee is
// not part of this excerpt.
// NOTE(review): `getLinearVarTypes().value()` is evaluated without a visible
// has_value() test. If the attribute can be absent while linear variables are
// present, this accesses an empty optional instead of emitting a diagnostic -
// confirm, and consider testing presence first.
3046LogicalResult WsloopOp::verify() {
3050 if (getLinearVars().size() &&
3051 getLinearVarTypes().value().size() != getLinearVars().size())
3052 return emitError() <<
"Ill-formed type attributes for linear variables";
3054 getReductionByref());
// Region verifier for WsloopOp; checks consistency of the 'omp.composite'
// attribute with the wrapper nesting.
//  - `isCompositeChildLeaf` is true when the parent op is a loop wrapper.
//  - With a nested wrapper: the composite attribute is required (per the
//    diagnostic) and the nested wrapper must be omp.simd.
//  - Without a nested wrapper: the attribute must be present exactly when
//    the parent is a loop wrapper.
3057LogicalResult WsloopOp::verifyRegions() {
3058 bool isCompositeChildLeaf =
3059 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3061 if (LoopWrapperInterface nested = getNestedWrapper()) {
3064 <<
"'omp.composite' attribute missing from composite wrapper";
3068 if (!isa<SimdOp>(nested))
3069 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3071 }
else if (isComposite() && !isCompositeChildLeaf) {
3073 <<
"'omp.composite' attribute present in non-composite wrapper";
3074 }
else if (!isComposite() && isCompositeChildLeaf) {
3076 <<
"'omp.composite' attribute missing from composite wrapper";
3087 const SimdOperands &clauses) {
3089 SimdOp::build(builder, state, clauses.alignedVars,
3091 clauses.linearVars, clauses.linearStepVars,
3092 clauses.linearVarTypes, clauses.linearModifiers,
3093 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3094 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3095 clauses.privateNeedsBarrier, clauses.reductionMod,
3096 clauses.reductionVars,
// Verifier for SimdOp. Visible checks:
//  - when both simdlen and safelen are present, simdlen must not exceed
//    safelen;
//  - the 'omp.composite' attribute must be present exactly when the parent op
//    is a loop wrapper;
//  - every privatizer symbol must resolve, and must not be FirstPrivate;
//  - when linear variables are present, the linear variable type list must
//    have the same number of entries.
// NOTE(review): `*privateSyms` and `getLinearVarTypes().value()` are used
// without a visible presence check in this excerpt (a guard may sit on one of
// the omitted lines) - confirm both optionals are populated on these paths.
3102LogicalResult SimdOp::verify() {
3103 if (getSimdlen().has_value() && getSafelen().has_value() &&
3104 getSimdlen().value() > getSafelen().value())
3106 <<
"simdlen clause and safelen clause are both present, but the "
3107 "simdlen value is not less than or equal to safelen value";
// True when the parent op is itself a loop wrapper.
3119 bool isCompositeChildLeaf =
3120 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3122 if (!isComposite() && isCompositeChildLeaf)
3124 <<
"'omp.composite' attribute missing from composite wrapper";
3126 if (isComposite() && !isCompositeChildLeaf)
3128 <<
"'omp.composite' attribute present in non-composite wrapper";
// Privatizer checks: each symbol is looked up relative to this operation.
3132 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3134 for (
const Attribute &sym : *privateSyms) {
3135 auto symRef = cast<SymbolRefAttr>(sym);
3136 omp::PrivateClauseOp privatizer =
3138 getOperation(), symRef);
3140 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3141 if (privatizer.getDataSharingType() ==
3142 DataSharingClauseType::FirstPrivate)
3143 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3147 if (getLinearVars().size() &&
3148 getLinearVarTypes().value().size() != getLinearVars().size())
3149 return emitError() <<
"Ill-formed type attributes for linear variables";
// Region verifier for SimdOp: omp.simd may not contain another loop wrapper;
// a non-null getNestedWrapper() is reported as an error.
3153LogicalResult SimdOp::verifyRegions() {
3154 if (getNestedWrapper())
3155 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3165 const DistributeOperands &clauses) {
3166 DistributeOp::build(builder, state, clauses.allocateVars,
3167 clauses.allocatorVars, clauses.distScheduleStatic,
3168 clauses.distScheduleChunkSize, clauses.order,
3169 clauses.orderMod, clauses.privateVars,
3171 clauses.privateNeedsBarrier);
3174LogicalResult DistributeOp::verify() {
3175 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3177 "dist_schedule_static being present";
3179 if (getAllocateVars().size() != getAllocatorVars().size())
3181 "expected equal sizes for allocate and allocator variables");
// Region verifier for DistributeOp; checks the 'omp.composite' attribute
// against the wrapper nesting.
//  - With a nested wrapper: the attribute is required (per the diagnostic).
//    A nested omp.wsloop is accepted only when `parentOp` is an omp.parallel
//    that is itself composite; any other nested wrapper must be omp.simd.
//  - Without a nested wrapper: carrying the attribute is an error.
// NOTE(review): `parentOp` is defined on a line outside this excerpt.
3186LogicalResult DistributeOp::verifyRegions() {
3187 if (LoopWrapperInterface nested = getNestedWrapper()) {
3190 <<
"'omp.composite' attribute missing from composite wrapper";
3193 if (isa<WsloopOp>(nested)) {
3195 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3196 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3197 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3198 "when a composite 'omp.parallel' is the direct "
3201 }
else if (!isa<SimdOp>(nested))
3202 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3204 }
else if (isComposite()) {
3206 <<
"'omp.composite' attribute present in non-composite wrapper";
3216LogicalResult DeclareMapperInfoOp::verify() {
// Region verifier for DeclareMapperOp: the terminator of the first block of
// the region must be a DeclareMapperInfoOp.
// NOTE(review): `getBlocks().front()` is taken without a visible emptiness
// check; confirm the region is guaranteed to have a block at this point.
3220LogicalResult DeclareMapperOp::verifyRegions() {
3221 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3222 getRegion().getBlocks().front().getTerminator()))
3223 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
// Region verifier for DeclareReductionOp. Region by region, the visible
// checks and diagnostics require:
//  - alloc region (optional): every omp.yield yields exactly one value of the
//    reduction type (getType());
//  - initializer region (mandatory): one block argument when there is no
//    alloc region, two when there is one; arguments and yielded value match
//    the reduction type;
//  - reduction region (mandatory): two arguments of the reduction type and a
//    single yielded value of that type;
//  - atomic reduction region (optional): two arguments of the same
//    pointer-like type whose element type, when known, is the reduction type;
//  - cleanup region (optional): one argument of the reduction type.
// NOTE(review): several guarding conditions (argument-count tests, the loop
// declaring `arg`) are on lines outside this excerpt; the summary above relies
// on the diagnostic texts for those.
3232LogicalResult DeclareReductionOp::verifyRegions() {
// Alloc region.
3233 if (!getAllocRegion().empty()) {
3234 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3235 if (yieldOp.getResults().size() != 1 ||
3236 yieldOp.getResults().getTypes()[0] !=
getType())
3237 return emitOpError() <<
"expects alloc region to yield a value "
3238 "of the reduction type";
// Initializer region.
3242 if (getInitializerRegion().empty())
3243 return emitOpError() <<
"expects non-empty initializer region";
3244 Block &initializerEntryBlock = getInitializerRegion().
front();
3247 if (!getAllocRegion().empty())
3248 return emitOpError() <<
"expects two arguments to the initializer region "
3249 "when an allocation region is used";
3251 if (getAllocRegion().empty())
3252 return emitOpError() <<
"expects one argument to the initializer region "
3253 "when no allocation region is used";
3256 <<
"expects one or two arguments to the initializer region";
3260 if (arg.getType() !=
getType())
3261 return emitOpError() <<
"expects initializer region argument to match "
3262 "the reduction type";
3264 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3265 if (yieldOp.getResults().size() != 1 ||
3266 yieldOp.getResults().getTypes()[0] !=
getType())
3267 return emitOpError() <<
"expects initializer region to yield a value "
3268 "of the reduction type";
// Reduction region.
3271 if (getReductionRegion().empty())
3272 return emitOpError() <<
"expects non-empty reduction region";
3273 Block &reductionEntryBlock = getReductionRegion().
front();
3278 return emitOpError() <<
"expects reduction region with two arguments of "
3279 "the reduction type";
3280 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3281 if (yieldOp.getResults().size() != 1 ||
3282 yieldOp.getResults().getTypes()[0] !=
getType())
3283 return emitOpError() <<
"expects reduction region to yield a value "
3284 "of the reduction type";
// Atomic reduction region.
3287 if (!getAtomicReductionRegion().empty()) {
3288 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3292 return emitOpError() <<
"expects atomic reduction region with two "
3293 "arguments of the same type";
3294 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3297 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3298 return emitOpError() <<
"expects atomic reduction region arguments to "
3299 "be accumulators containing the reduction type";
// Cleanup region.
3302 if (getCleanupRegion().empty())
3304 Block &cleanupEntryBlock = getCleanupRegion().
front();
3307 return emitOpError() <<
"expects cleanup region with one argument "
3308 "of the reduction type";
3318 const TaskOperands &clauses) {
3321 builder, state, clauses.iterated, clauses.affinityVars,
3322 clauses.allocateVars, clauses.allocatorVars,
3323 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3324 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3325 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3327 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3328 clauses.priority, clauses.privateVars,
3330 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3333LogicalResult TaskOp::verify() {
3334 LogicalResult verifyDependVars =
3336 getDependIteratedKinds(), getDependIterated());
3337 return failed(verifyDependVars)
3340 getInReductionVars(),
3341 getInReductionByref());
3349 const TaskgroupOperands &clauses) {
3351 TaskgroupOp::build(builder, state, clauses.allocateVars,
3352 clauses.allocatorVars, clauses.taskReductionVars,
3357LogicalResult TaskgroupOp::verify() {
3359 getTaskReductionVars(),
3360 getTaskReductionByref());
3368 const TaskloopContextOperands &clauses) {
3370 TaskloopContextOp::build(
3371 builder, state, clauses.allocateVars, clauses.allocatorVars,
3372 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3373 clauses.inReductionVars,
3375 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3376 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3377 clauses.privateVars,
3379 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3384TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3385 return cast<TaskloopWrapperOp>(
3387 return isa<TaskloopWrapperOp>(op);
// Verifier for TaskloopContextOp. Visible checks:
//  - allocate and allocator operand lists must have equal sizes;
//  - reduction and in_reduction operands are passed to further checks whose
//    callee is not part of this excerpt;
//  - `nogroup` may not be combined with a non-empty reduction list;
//  - no value may appear in both the reduction and in_reduction lists;
//  - grainsize and num_tasks are mutually exclusive.
3391LogicalResult TaskloopContextOp::verify() {
3392 if (getAllocateVars().size() != getAllocatorVars().size())
3394 "expected equal sizes for allocate and allocator variables");
3396 getReductionVars(), getReductionByref())) ||
3398 getInReductionVars(),
3399 getInReductionByref())))
3402 if (!getReductionVars().empty() && getNogroup())
3403 return emitError(
"if a reduction clause is present on the taskloop "
3404 "directive, the nogroup clause must not be specified");
// Linear membership test per reduction variable against in_reduction.
3405 for (
auto var : getReductionVars()) {
3406 if (llvm::is_contained(getInReductionVars(), var))
3407 return emitError(
"the same list item cannot appear in both a reduction "
3408 "and an in_reduction clause");
3411 if (getGrainsize() && getNumTasks()) {
3413 "the grainsize clause and num_tasks clause are mutually exclusive and "
3414 "may not appear on the same taskloop directive");
// Region verifier for TaskloopContextOp. Visible checks:
//  - the region must be non-empty;
//  - exactly one TaskloopWrapperOp must be found (the counting code is only
//    partly visible; the diagnostic reports the count found);
//  - every lower bound, upper bound and step of the wrapped omp.loop_nest must
//    satisfy `isValidBoundValue`; per the diagnostic this means being defined
//    outside the taskloop.context region or by pure, regionless operations
//    that do not depend on block arguments.
// NOTE(review): `loopNestOp` comes from dyn_cast and is used below; any null
// check sits on lines outside this excerpt - confirm it exists.
3420LogicalResult TaskloopContextOp::verifyRegions() {
// Fixed: the reference declarator had been mis-decoded (`&reg` + `ion`).
3421 Region &region = getRegion();
3423 return emitOpError() <<
"expected non-empty region";
3426 return isa<TaskloopWrapperOp>(op);
3430 <<
"expected exactly 1 TaskloopWrapperOp directly nested in "
3432 << count <<
" were found";
3433 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3435 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
// Declared as std::function so the lambda can call itself on the operands of
// the defining op.
3441 std::function<
bool(
Value)> isValidBoundValue = [&](
Value value) ->
bool {
3442 Region *valueRegion = value.getParentRegion();
3448 Operation *defOp = value.getDefiningOp();
3452 return llvm::all_of(defOp->
getOperands(), isValidBoundValue);
// True when any value in `range` fails isValidBoundValue.
3454 auto hasUnsupportedTaskloopLocalBound = [&](
OperandRange range) ->
bool {
3455 return llvm::any_of(range,
3456 [&](
Value value) {
return !isValidBoundValue(value); });
3459 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3460 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3461 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3463 <<
"expects loop bounds and steps to be defined outside of the "
3464 "taskloop.context region or by pure, regionless operations "
3465 "that do not depend on block arguments";
3476 const TaskloopWrapperOperands &clauses) {
3477 TaskloopWrapperOp::build(builder, state);
// Returns the parent operation viewed as a TaskloopContextOp. The result
// comes from dyn_cast, so it is a null op when the parent is of another kind.
3480TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3481 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
// Verifier for TaskloopWrapperOp: fetches the enclosing taskloop context and
// reports an error when the op is not nested in one. The condition guarding
// the diagnostic is on a line outside this excerpt (presumably a null test on
// `context` - confirm).
3484LogicalResult TaskloopWrapperOp::verify() {
3485 TaskloopContextOp context = getTaskloopContext();
3487 return emitOpError() <<
"expected to be nested in a taskloop context op";
// Region verifier for TaskloopWrapperOp; checks the 'omp.composite' attribute
// against the wrapper nesting.
//  - With a nested wrapper: the attribute is required (per the diagnostic)
//    and the nested wrapper must be omp.simd.
//  - Without a nested wrapper: carrying the attribute is an error.
3491LogicalResult TaskloopWrapperOp::verifyRegions() {
3492 if (LoopWrapperInterface nested = getNestedWrapper()) {
3495 <<
"'omp.composite' attribute missing from composite wrapper";
3499 if (!isa<SimdOp>(nested))
3500 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3501 }
else if (isComposite()) {
3503 <<
"'omp.composite' attribute present in non-composite wrapper";
3527 for (
auto &iv : ivs)
3528 iv.type = loopVarType;
3533 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3549 "collapse_num_loops",
3554 auto parseTiles = [&]() -> ParseResult {
3558 tiles.push_back(
tile);
3567 if (tiles.size() > 0)
// Printer body fragment (the enclosing signature is outside this excerpt;
// presumably the custom printer of the loop-nest op - confirm).
// Emits, in order: the argument list and the type of the first argument, the
// lower and upper bounds, the step operands, `collapse(N)` only when N > 1,
// and `tiles(...)` from `tiles.value()`.
// NOTE(review): `args` and `tiles` are defined on lines outside this excerpt.
// Fixed: the reference declarator had been mis-decoded (`&reg` + `ion`).
3586 Region &region = getRegion();
3588 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3589 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3590 if (getLoopInclusive())
3592 p <<
"step (" << getLoopSteps() <<
") ";
3593 if (
int64_t numCollapse = getCollapseNumLoops())
3594 if (numCollapse > 1)
3595 p <<
"collapse(" << numCollapse <<
") ";
3598 p <<
"tiles(" << tiles.value() <<
") ";
3604 const LoopNestOperands &clauses) {
3606 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3607 clauses.loopLowerBounds, clauses.loopUpperBounds,
3608 clauses.loopSteps, clauses.loopInclusive,
// Verifier for LoopNestOp. Visible checks:
//  - at least one loop (non-empty lower-bound list);
//  - as many lower bounds as induction variables, with matching types
//    pairwise;
//  - a collapse value, when non-zero, no larger than the number of IVs;
//  - no more tile sizes than IVs;
//  - the parent op must be a loop wrapper.
// NOTE(review): `tiles` is defined on a line outside this excerpt; whether
// `tiles.value()` is guarded by a presence test cannot be confirmed from here.
3612LogicalResult LoopNestOp::verify() {
3613 if (getLoopLowerBounds().empty())
3614 return emitOpError() <<
"must represent at least one loop";
// The size check precedes zip_equal, which is then applied to both ranges.
3616 if (getLoopLowerBounds().size() != getIVs().size())
3617 return emitOpError() <<
"number of range arguments and IVs do not match";
3619 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3620 if (lb.getType() != iv.getType())
3622 <<
"range argument type does not match corresponding IV type";
3625 uint64_t numIVs = getIVs().size();
3627 if (
const auto &numCollapse = getCollapseNumLoops())
3628 if (numCollapse > numIVs)
3630 <<
"collapse value is larger than the number of loops";
3633 if (tiles.value().size() > numIVs)
3634 return emitOpError() <<
"too few canonical loops for tile dimensions";
3636 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3637 return emitOpError() <<
"expects parent op to be a loop wrapper";
// Appends loop wrappers to `wrappers`: for as long as `parent` is a
// LoopWrapperInterface, it is pushed onto the list.
// NOTE(review): the parameter list, the initialisation of `parent` and the
// step that advances it are on lines outside this excerpt; presumably the loop
// walks up the parent-op chain starting at this op's parent - confirm.
3642void LoopNestOp::gatherWrappers(
3645 while (
auto wrapper =
3646 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3647 wrappers.push_back(wrapper);
3656std::tuple<NewCliOp, OpOperand *, OpOperand *>
3662 return {{},
nullptr,
nullptr};
3665 "Unexpected type of cli");
3671 auto op = cast<LoopTransformationInterface>(use.getOwner());
3673 unsigned opnum = use.getOperandNumber();
3674 if (op.isGeneratee(opnum)) {
3675 assert(!gen &&
"Each CLI may have at most one def");
3677 }
else if (op.isApplyee(opnum)) {
3678 assert(!cons &&
"Each CLI may have at most one consumer");
3681 llvm_unreachable(
"Unexpected operand for a CLI");
3685 return {create, gen, cons};
3708 std::string cliName{
"cli"};
3712 .Case([&](CanonicalLoopOp op) {
3715 .Case([&](UnrollHeuristicOp op) -> std::string {
3716 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3718 .Case([&](FuseOp op) -> std::string {
3719 unsigned opnum =
generator->getOperandNumber();
3722 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3723 return "canonloop_fuse";
3727 .Case([&](TileOp op) -> std::string {
3728 auto [generateesFirst, generateesCount] =
3729 op.getGenerateesODSOperandIndexAndLength();
3730 unsigned firstGrid = generateesFirst;
3731 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3732 unsigned end = generateesFirst + generateesCount;
3733 unsigned opnum =
generator->getOperandNumber();
3735 if (firstGrid <= opnum && opnum < firstIntratile) {
3736 unsigned gridnum = opnum - firstGrid + 1;
3737 return (
"grid" + Twine(gridnum)).str();
3739 if (firstIntratile <= opnum && opnum < end) {
3740 unsigned intratilenum = opnum - firstIntratile + 1;
3741 return (
"intratile" + Twine(intratilenum)).str();
3743 llvm_unreachable(
"Unexpected generatee argument");
3745 .DefaultUnreachable(
"TODO: Custom name for this operation");
3748 setNameFn(
result, cliName);
3751LogicalResult NewCliOp::verify() {
3752 Value cli = getResult();
3755 "Unexpected type of cli");
3761 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3763 unsigned opnum = use.getOperandNumber();
3764 if (op.isGeneratee(opnum)) {
3767 emitOpError(
"CLI must have at most one generator");
3769 .
append(
"first generator here:");
3771 .
append(
"second generator here:");
3776 }
else if (op.isApplyee(opnum)) {
3779 emitOpError(
"CLI must have at most one consumer");
3781 .
append(
"first consumer here:")
3785 .
append(
"second consumer here:")
3792 llvm_unreachable(
"Unexpected operand for a CLI");
3800 .
append(
"see consumer here: ")
3823 setNameFn(&getRegion().front(),
"body_entry");
3826void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3834 p <<
'(' << getCli() <<
')';
3835 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3836 <<
" in range(" << getTripCount() <<
") ";
3846 CanonicalLoopInfoType cliType =
3847 CanonicalLoopInfoType::get(parser.
getContext());
3872 if (parser.
parseRegion(*region, {inductionVariable}))
3877 result.operands.append(cliOperand);
3883 return mlir::success();
3886LogicalResult CanonicalLoopOp::verify() {
3889 if (!getRegion().empty()) {
3890 Region ®ion = getRegion();
3893 "Canonical loop region must have exactly one argument");
3897 "Region argument must be the same type as the trip count");
/// Returns the loop's induction variable, which is the first block argument
/// (index 0) of this op's region.
/// NOTE(review): this accessor does not itself check that the region has an
/// argument — presumably the op verifier guarantees it; confirm before calling
/// on unverified IR.
3903Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3905std::pair<unsigned, unsigned>
3906CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3911std::pair<unsigned, unsigned>
3912CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3913 return getODSOperandIndexAndLength(odsIndex_cli);
3927 p <<
'(' << getApplyee() <<
')';
3934 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
3957 return mlir::success();
3960std::pair<unsigned, unsigned>
3961UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3962 return getODSOperandIndexAndLength(odsIndex_applyee);
3965std::pair<unsigned, unsigned>
3966UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3977 if (!generatees.empty())
3978 p <<
'(' << llvm::interleaved(generatees) <<
')';
3980 if (!applyees.empty())
3981 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4023 bool isOnlyCanonLoops =
true;
4025 for (
Value applyee : op.getApplyees()) {
4026 auto [create, gen, cons] =
decodeCli(applyee);
4029 return op.emitOpError() <<
"applyee CLI has no generator";
4031 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4032 canonLoops.push_back(loop);
4034 isOnlyCanonLoops =
false;
4039 if (!isOnlyCanonLoops)
4043 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
4044 auto parentLoop = canonLoops[i - 1];
4045 auto loop = canonLoops[i];
4047 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4048 return op.emitOpError()
4049 <<
"tiled loop nest must be nested within each other";
4051 parentIVs.insert(parentLoop.getInductionVar());
4056 bool isPerfectlyNested = [&]() {
4057 auto &parentBody = parentLoop.getRegion();
4058 if (!parentBody.hasOneBlock())
4060 auto &parentBlock = parentBody.getBlocks().front();
4062 auto nestedLoopIt = parentBlock.begin();
4063 if (nestedLoopIt == parentBlock.end() ||
4064 (&*nestedLoopIt != loop.getOperation()))
4067 auto termIt = std::next(nestedLoopIt);
4068 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4071 if (std::next(termIt) != parentBlock.end())
4076 if (!isPerfectlyNested)
4077 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4079 if (parentIVs.contains(loop.getTripCount()))
4080 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4097LogicalResult TileOp::verify() {
4098 if (getApplyees().empty())
4099 return emitOpError() <<
"must apply to at least one loop";
4101 if (getSizes().size() != getApplyees().size())
4102 return emitOpError() <<
"there must be one tile size for each applyee";
4104 if (!getGeneratees().empty() &&
4105 2 * getSizes().size() != getGeneratees().size())
4107 <<
"expecting two times the number of generatees than applyees";
4112std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4113 return getODSOperandIndexAndLength(odsIndex_applyees);
4116std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4117 return getODSOperandIndexAndLength(odsIndex_generatees);
4127 if (!generatees.empty())
4128 p <<
'(' << llvm::interleaved(generatees) <<
')';
4130 if (!applyees.empty())
4131 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4134LogicalResult FuseOp::verify() {
4135 if (getApplyees().size() < 2)
4136 return emitOpError() <<
"must apply to at least two loops";
4138 if (getFirst().has_value() && getCount().has_value()) {
4139 int64_t first = getFirst().value();
4140 int64_t count = getCount().value();
4141 if ((
unsigned)(first + count - 1) > getApplyees().size())
4142 return emitOpError() <<
"the numbers of applyees must be at least first "
4143 "minus one plus count attributes";
4144 if (!getGeneratees().empty() &&
4145 getGeneratees().size() != getApplyees().size() + 1 - count)
4146 return emitOpError() <<
"the number of generatees must be the number of "
4147 "aplyees plus one minus count";
4150 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4152 <<
"in a complete fuse the number of generatees must be exactly 1";
4154 for (
auto &&applyee : getApplyees()) {
4155 auto [create, gen, cons] =
decodeCli(applyee);
4158 return emitOpError() <<
"applyee CLI has no generator";
4159 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4162 <<
"currently only supports omp.canonical_loop as applyee";
4166std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4167 return getODSOperandIndexAndLength(odsIndex_applyees);
4170std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4171 return getODSOperandIndexAndLength(odsIndex_generatees);
4179 const CriticalDeclareOperands &clauses) {
4180 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4183LogicalResult CriticalDeclareOp::verify() {
4188 if (getNameAttr()) {
4189 SymbolRefAttr symbolRef = getNameAttr();
4193 return emitOpError() <<
"expected symbol reference " << symbolRef
4194 <<
" to point to a critical declaration";
4214 return op.
emitOpError() <<
"must be nested inside of a loop";
4218 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4219 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4221 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4222 "have an ordered clause";
4224 if (hasRegion && orderedAttr.getInt() != 0)
4225 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4226 "have a parameter present";
4228 if (!hasRegion && orderedAttr.getInt() == 0)
4229 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4230 "have a parameter present";
4231 }
else if (!isa<SimdOp>(wrapper)) {
4232 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4233 "or worksharing simd loop";
4239 const OrderedOperands &clauses) {
4240 OrderedOp::build(builder, state, clauses.doacrossDependType,
4241 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4244LogicalResult OrderedOp::verify() {
4248 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4249 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4250 return emitOpError() <<
"number of variables in depend clause does not "
4251 <<
"match number of iteration variables in the "
4258 const OrderedRegionOperands &clauses) {
4259 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4269 const TaskwaitOperands &clauses) {
4271 TaskwaitOp::build(builder, state,
nullptr,
4280LogicalResult AtomicReadOp::verify() {
4281 if (verifyCommon().
failed())
4282 return mlir::failure();
4284 if (
auto mo = getMemoryOrder()) {
4285 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4286 *mo == ClauseMemoryOrderKind::Release) {
4288 "memory-order must not be acq_rel or release for atomic reads");
4298LogicalResult AtomicWriteOp::verify() {
4299 if (verifyCommon().
failed())
4300 return mlir::failure();
4302 if (
auto mo = getMemoryOrder()) {
4303 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4304 *mo == ClauseMemoryOrderKind::Acquire) {
4306 "memory-order must not be acq_rel or acquire for atomic writes");
4316LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4322 if (
Value writeVal = op.getWriteOpVal()) {
4324 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4330LogicalResult AtomicUpdateOp::verify() {
4331 if (verifyCommon().
failed())
4332 return mlir::failure();
4334 if (
auto mo = getMemoryOrder()) {
4335 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4336 *mo == ClauseMemoryOrderKind::Acquire) {
4338 "memory-order must not be acq_rel or acquire for atomic updates");
/// Region verification hook for the atomic update op. It performs no checks of
/// its own: it forwards to verifyRegionsCommon() and returns that result
/// unchanged.
4345LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4351AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4352 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4354 return dyn_cast<AtomicReadOp>(getSecondOp());
4357AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4358 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4360 return dyn_cast<AtomicWriteOp>(getSecondOp());
4363AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4364 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4366 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4369LogicalResult AtomicCaptureOp::verify() {
4373LogicalResult AtomicCaptureOp::verifyRegions() {
4374 if (verifyRegionsCommon().
failed())
4375 return mlir::failure();
4377 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4379 "operations inside capture region must not have hint clause");
4381 if (getFirstOp()->getAttr(
"memory_order") ||
4382 getSecondOp()->getAttr(
"memory_order"))
4384 "operations inside capture region must not have memory_order clause");
4393 const CancelOperands &clauses) {
4394 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4407LogicalResult CancelOp::verify() {
4408 ClauseCancellationConstructType cct = getCancelDirective();
4411 if (!structuralParent)
4412 return emitOpError() <<
"Orphaned cancel construct";
4414 if ((cct == ClauseCancellationConstructType::Parallel) &&
4415 !mlir::isa<ParallelOp>(structuralParent)) {
4416 return emitOpError() <<
"cancel parallel must appear "
4417 <<
"inside a parallel region";
4419 if (cct == ClauseCancellationConstructType::Loop) {
4422 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4426 <<
"cancel loop must appear inside a worksharing-loop region";
4428 if (wsloopOp.getNowaitAttr()) {
4429 return emitError() <<
"A worksharing construct that is canceled "
4430 <<
"must not have a nowait clause";
4432 if (wsloopOp.getOrderedAttr()) {
4433 return emitError() <<
"A worksharing construct that is canceled "
4434 <<
"must not have an ordered clause";
4437 }
else if (cct == ClauseCancellationConstructType::Sections) {
4441 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4443 return emitOpError() <<
"cancel sections must appear "
4444 <<
"inside a sections region";
4446 if (sectionsOp.getNowait()) {
4447 return emitError() <<
"A sections construct that is canceled "
4448 <<
"must not have a nowait clause";
4451 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4452 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4453 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4454 return emitOpError() <<
"cancel taskgroup must appear "
4455 <<
"inside a task region";
4465 const CancellationPointOperands &clauses) {
4466 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4469LogicalResult CancellationPointOp::verify() {
4470 ClauseCancellationConstructType cct = getCancelDirective();
4473 if (!structuralParent)
4474 return emitOpError() <<
"Orphaned cancellation point";
4476 if ((cct == ClauseCancellationConstructType::Parallel) &&
4477 !mlir::isa<ParallelOp>(structuralParent)) {
4478 return emitOpError() <<
"cancellation point parallel must appear "
4479 <<
"inside a parallel region";
4483 if ((cct == ClauseCancellationConstructType::Loop) &&
4484 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4485 return emitOpError() <<
"cancellation point loop must appear "
4486 <<
"inside a worksharing-loop region";
4488 if ((cct == ClauseCancellationConstructType::Sections) &&
4489 !mlir::isa<omp::SectionOp>(structuralParent)) {
4490 return emitOpError() <<
"cancellation point sections must appear "
4491 <<
"inside a sections region";
4493 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4494 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4495 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4496 return emitOpError() <<
"cancellation point taskgroup must appear "
4497 <<
"inside a task region";
4506LogicalResult MapBoundsOp::verify() {
4507 auto extent = getExtent();
4509 if (!extent && !upperbound)
4510 return emitError(
"expected extent or upperbound.");
4517 PrivateClauseOp::build(
4518 odsBuilder, odsState, symName, type,
4519 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4520 DataSharingClauseType::Private));
4523LogicalResult PrivateClauseOp::verifyRegions() {
4524 Type argType = getArgType();
4525 auto verifyTerminator = [&](
Operation *terminator,
4526 bool yieldsValue) -> LogicalResult {
4530 if (!llvm::isa<YieldOp>(terminator))
4532 <<
"expected exit block terminator to be an `omp.yield` op.";
4534 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4535 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4538 if (yieldedTypes.empty())
4542 <<
"Did not expect any values to be yielded.";
4545 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4549 <<
"Invalid yielded value. Expected type: " << argType
4552 if (yieldedTypes.empty())
4555 error << yieldedTypes;
4561 StringRef regionName,
4562 bool yieldsValue) -> LogicalResult {
4563 assert(!region.
empty());
4567 <<
"`" << regionName <<
"`: "
4568 <<
"expected " << expectedNumArgs
4571 for (
Block &block : region) {
4573 if (!block.mightHaveTerminator())
4576 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4584 for (
Region *region : getRegions())
4585 for (
Type ty : region->getArgumentTypes())
4587 return emitError() <<
"Region argument type mismatch: got " << ty
4588 <<
" expected " << argType <<
".";
4591 if (!initRegion.
empty() &&
4596 DataSharingClauseType dsType = getDataSharingType();
4598 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4599 return emitError(
"`private` clauses do not require a `copy` region.");
4601 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4603 "`firstprivate` clauses require at least a `copy` region.");
4605 if (dsType == DataSharingClauseType::FirstPrivate &&
4610 if (!getDeallocRegion().empty() &&
4623 const MaskedOperands &clauses) {
4624 MaskedOp::build(builder, state, clauses.filteredThreadId);
4632 const ScanOperands &clauses) {
4633 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4636LogicalResult ScanOp::verify() {
4637 if (hasExclusiveVars() == hasInclusiveVars())
4639 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4640 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4641 if (parentWsLoopOp.getReductionModAttr() &&
4642 parentWsLoopOp.getReductionModAttr().getValue() ==
4643 ReductionModifier::inscan)
4646 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4647 if (parentSimdOp.getReductionModAttr() &&
4648 parentSimdOp.getReductionModAttr().getValue() ==
4649 ReductionModifier::inscan)
4652 return emitError(
"SCAN directive needs to be enclosed within a parent "
4653 "worksharing loop construct or SIMD construct with INSCAN "
4654 "reduction modifier");
4659LogicalResult AllocateDirOp::verify() {
4660 std::optional<uint64_t> align = this->getAlign();
4662 if (align.has_value()) {
4663 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4664 return emitError() <<
"ALIGN value : " << align.value()
4665 <<
" must be power of 2";
4675mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4676 return getInTypeAttr().getValue();
4685 bool hasOperands =
false;
4686 std::int32_t typeparamsSize = 0;
4692 return mlir::failure();
4694 return mlir::failure();
4696 return mlir::failure();
4700 return mlir::failure();
4701 result.addAttribute(
"in_type", mlir::TypeAttr::get(intype));
4708 return mlir::failure();
4709 typeparamsSize = operands.size();
4712 std::int32_t shapeSize = 0;
4716 return mlir::failure();
4717 shapeSize = operands.size() - typeparamsSize;
4719 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4720 typeVec.push_back(idxTy);
4726 return mlir::failure();
4731 return mlir::failure();
4734 result.addAttribute(
"operandSegmentSizes",
4738 return mlir::failure();
4739 return mlir::success();
4754 if (!getTypeparams().empty()) {
4755 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4762 {
"in_type",
"operandSegmentSizes"});
4765llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4767 if (!mlir::dyn_cast<IntegerType>(outType))
4769 return mlir::success();
4776LogicalResult WorkdistributeOp::verify() {
4778 Region ®ion = getRegion();
4783 if (entryBlock.
empty())
4784 return emitOpError(
"region must contain a structured block");
4786 bool hasTerminator =
false;
4787 for (
Block &block : region) {
4788 if (isa<TerminatorOp>(block.back())) {
4789 if (hasTerminator) {
4790 return emitOpError(
"region must have exactly one terminator");
4792 hasTerminator =
true;
4795 if (!hasTerminator) {
4796 return emitOpError(
"region must be terminated with omp.terminator");
4800 if (isa<BarrierOp>(op)) {
4802 "explicit barriers are not allowed in workdistribute region");
4805 if (isa<ParallelOp>(op)) {
4807 "nested parallel constructs not allowed in workdistribute");
4809 if (isa<TeamsOp>(op)) {
4811 "nested teams constructs not allowed in workdistribute");
4815 if (walkResult.wasInterrupted())
4819 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4820 return emitOpError(
"workdistribute must be nested under teams");
4828LogicalResult DeclareSimdOp::verify() {
4831 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4833 return emitOpError() <<
"must be nested inside a function";
4835 if (getInbranch() && getNotinbranch())
4836 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
4846 const DeclareSimdOperands &clauses) {
4848 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4850 clauses.linearVars, clauses.linearStepVars,
4851 clauses.linearVarTypes, clauses.linearModifiers,
4852 clauses.notinbranch, clauses.simdlen,
4853 clauses.uniformVars);
4870 return mlir::failure();
4871 return mlir::success();
4878 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
4881 p << uniformVars[i] <<
" : " << uniformTypes[i];
4896 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4897 [&]() -> ParseResult {
return success(); })))
4931 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4932 if (parser.parseArgument(arg))
4936 if (succeeded(parser.parseOptionalColon())) {
4937 if (parser.parseType(arg.type))
4940 arg.type = parser.getBuilder().getIndexType();
4952 OpAsmParser::UnresolvedOperand lb, ub, st;
4953 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
4954 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
4955 parser.parseOperand(st))
4960 steps.push_back(st);
4968 if (ivArgs.size() != lbs.size())
4970 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
4973 for (
auto &arg : ivArgs) {
4974 lbTypes.push_back(arg.type);
4975 ubTypes.push_back(arg.type);
4976 stepTypes.push_back(arg.type);
4996 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
4999 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
5007LogicalResult IteratorOp::verify() {
5008 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
5010 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
5012 for (
auto [lb,
ub, step] : llvm::zip_equal(
5013 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5015 return emitOpError() <<
"loop step must not be zero";
5019 IntegerAttr stepAttr;
5025 const APInt &lbVal = lbAttr.getValue();
5026 const APInt &ubVal = ubAttr.getValue();
5027 const APInt &stepVal = stepAttr.getValue();
5028 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5029 return emitOpError() <<
"positive loop step requires lower bound to be "
5030 "less than or equal to upper bound";
5031 if (stepVal.isNegative() && lbVal.slt(ubVal))
5032 return emitOpError() <<
"negative loop step requires lower bound to be "
5033 "greater than or equal to upper bound";
5036 Block &
b = getRegion().front();
5037 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
5040 return emitOpError() <<
"region must be terminated by omp.yield";
5042 if (yield.getNumOperands() != 1)
5044 <<
"omp.yield in omp.iterator region must yield exactly one value";
5046 mlir::Type yieldedTy = yield.getOperand(0).getType();
5047 mlir::Type elemTy = iteratedTy.getElementType();
5049 if (yieldedTy != elemTy)
5050 return emitOpError() <<
"omp.iterated element type (" << elemTy
5051 <<
") does not match omp.yield operand type ("
5052 << yieldedTy <<
")";
5057#define GET_ATTRDEF_CLASSES
5058#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5060#define GET_OP_CLASSES
5061#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5063#define GET_TYPEDEF_CLASSES
5064#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
operand_range getOperands()
Returns an iterator on the underlying Value's.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.