26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/PostOrderIterator.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/STLForwardCompat.h"
30#include "llvm/ADT/SmallString.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/ADT/bit.h"
35#include "llvm/Support/InterleavedRange.h"
41#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
51 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
65struct MemRefPointerLikeModel
66 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 return llvm::cast<MemRefType>(pointer).getElementType();
73struct LLVMPointerPointerLikeModel
74 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
100 bool isRegionArgOfOp;
110 assert(isRegionArgOfOp &&
"Must describe a region operand");
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp &&
"Must describe a region operand");
119 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
123 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
126 bool isLoopOp()
const {
127 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
139 void skipIf(
bool v =
true) { skip = skip || v; }
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp =
true;
162 for (
Block *
b : traversal) {
164 if (&op == o && !found) {
168 if (op.getNumRegions()) {
171 isOnlyContainerOp =
false;
173 if (found && !isOnlyContainerOp)
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp =
false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
188 Component ®ionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp =
true;
190 regionArgOfOperation.isUnique =
true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
204 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
208 llvm_unreachable(
"Region not child of its parent operation");
210 regionArgOfOperation.isUnique =
false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
237 numSurroundingLoops = 0;
239 c.getLoopDepth() = numSurroundingLoops;
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
248 bool isLoopNest =
false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 }
else if (isLoopNest) {
258 c.skipIf(c.isUnique);
262 if (c.getLoopDepth() == 0)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
275 bool newRegion =
true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
284 if (!c.isRegionArgOfOp && c.getContainerOp())
290 llvm::raw_svector_ostream NameOS(Name);
291 for (
auto &c : llvm::reverse(components)) {
295 if (c.isRegionArgOfOp)
296 NameOS <<
"_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS <<
"_d" << c.getLoopDepth();
300 NameOS <<
"_s" << c.getOpPos();
303 return NameOS.str().str();
306void OpenMPDialect::initialize() {
309#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312#define GET_ATTRDEF_LIST
313#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316#define GET_TYPEDEF_LIST
317#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
334 mlir::LLVM::GlobalOp::attachInterface<
337 mlir::LLVM::LLVMFuncOp::attachInterface<
340 mlir::func::FuncOp::attachInterface<
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
385 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
387 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
388 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
396template <
typename ClauseAttr>
398 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
407 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
410template <
typename ClauseAttr>
412 p << stringifyEnum(attr.getValue());
435 linearVars.push_back(var);
436 linearTypes.push_back(type);
437 linearStepVars.push_back(stepVar);
446 size_t linearVarsSize = linearVars.size();
447 for (
unsigned i = 0; i < linearVarsSize; ++i) {
448 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
450 if (linearStepVars.size() > i)
451 p <<
" = " << linearStepVars[i];
452 p <<
" : " << linearVars[i].getType() << separator;
465 for (
const auto &it : nontemporalVars)
466 if (!nontemporalItems.insert(it).second)
467 return op->
emitOpError() <<
"nontemporal variable used more than once";
476 std::optional<ArrayAttr> alignments,
479 if (!alignedVars.empty()) {
480 if (!alignments || alignments->size() != alignedVars.size())
482 <<
"expected as many alignment values as aligned variables";
485 return op->
emitOpError() <<
"unexpected alignment values attribute";
491 for (
auto it : alignedVars)
492 if (!alignedItems.insert(it).second)
493 return op->
emitOpError() <<
"aligned variable used more than once";
499 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
500 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
501 if (intAttr.getValue().sle(0))
502 return op->
emitOpError() <<
"alignment should be greater than 0";
504 return op->
emitOpError() <<
"expected integer alignment";
521 if (parser.parseOperand(alignedVars.emplace_back()) ||
522 parser.parseColonType(alignedTypes.emplace_back()) ||
523 parser.parseArrow() ||
524 parser.parseAttribute(alignmentVec.emplace_back())) {
531 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
538 std::optional<ArrayAttr> alignments) {
539 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
542 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
543 p <<
" -> " << (*alignments)[i];
554 if (modifiers.size() > 2)
556 for (
const auto &mod : modifiers) {
559 auto symbol = symbolizeScheduleModifier(mod);
562 <<
" unknown modifier type: " << mod;
567 if (modifiers.size() == 1) {
568 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
569 modifiers.push_back(modifiers[0]);
570 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
572 }
else if (modifiers.size() == 2) {
575 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
576 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
578 <<
" incorrect modifier order";
594 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
595 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
600 std::optional<mlir::omp::ClauseScheduleKind> schedule =
601 symbolizeClauseScheduleKind(keyword);
605 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
607 case ClauseScheduleKind::Static:
608 case ClauseScheduleKind::Dynamic:
609 case ClauseScheduleKind::Guided:
615 chunkSize = std::nullopt;
618 case ClauseScheduleKind::Auto:
619 case ClauseScheduleKind::Runtime:
620 chunkSize = std::nullopt;
629 modifiers.push_back(mod);
635 if (!modifiers.empty()) {
637 if (std::optional<ScheduleModifier> mod =
638 symbolizeScheduleModifier(modifiers[0])) {
639 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
641 return parser.
emitError(loc,
"invalid schedule modifier");
644 if (modifiers.size() > 1) {
645 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
655 ClauseScheduleKindAttr scheduleKind,
656 ScheduleModifierAttr scheduleMod,
657 UnitAttr scheduleSimd,
Value scheduleChunk,
658 Type scheduleChunkType) {
659 p << stringifyClauseScheduleKind(scheduleKind.getValue());
661 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
663 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
675 ClauseOrderKindAttr &order,
676 OrderModifierAttr &orderMod) {
681 if (std::optional<OrderModifier> enumValue =
682 symbolizeOrderModifier(enumStr)) {
683 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
690 if (std::optional<ClauseOrderKind> enumValue =
691 symbolizeClauseOrderKind(enumStr)) {
692 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
695 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
699 ClauseOrderKindAttr order,
700 OrderModifierAttr orderMod) {
702 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
704 p << stringifyClauseOrderKind(order.getValue());
707template <
typename ClauseTypeAttr,
typename ClauseType>
710 std::optional<OpAsmParser::UnresolvedOperand> &operand,
712 std::optional<ClauseType> (*symbolizeClause)(StringRef),
713 StringRef clauseName) {
716 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
717 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
722 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
732 <<
"expected " << clauseName <<
" operand";
735 if (operand.has_value()) {
743template <
typename ClauseTypeAttr,
typename ClauseType>
746 ClauseTypeAttr prescriptiveness,
Value operand,
748 StringRef (*stringifyClauseType)(ClauseType)) {
750 if (prescriptiveness)
751 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
754 p << operand <<
": " << operandType;
764 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
765 Type &grainsizeType) {
767 parser, grainsizeMod, grainsize, grainsizeType,
768 &symbolizeClauseGrainsizeType,
"grainsize");
772 ClauseGrainsizeTypeAttr grainsizeMod,
775 p, op, grainsizeMod, grainsize, grainsizeType,
776 &stringifyClauseGrainsizeType);
786 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
787 Type &numTasksType) {
789 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
794 ClauseNumTasksTypeAttr numTasksMod,
797 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
806 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
807 SmallVectorImpl<Type> &types;
808 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
809 SmallVectorImpl<Type> &types)
810 : vars(vars), types(types) {}
812struct PrivateParseArgs {
813 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
814 llvm::SmallVectorImpl<Type> &types;
816 UnitAttr &needsBarrier;
818 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
819 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
820 UnitAttr &needsBarrier,
822 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
823 mapIndices(mapIndices) {}
826struct ReductionParseArgs {
827 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
828 SmallVectorImpl<Type> &types;
831 ReductionModifierAttr *modifier;
832 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
834 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
835 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
838struct AllRegionParseArgs {
839 std::optional<MapParseArgs> hasDeviceAddrArgs;
840 std::optional<MapParseArgs> hostEvalArgs;
841 std::optional<ReductionParseArgs> inReductionArgs;
842 std::optional<MapParseArgs> mapArgs;
843 std::optional<PrivateParseArgs> privateArgs;
844 std::optional<ReductionParseArgs> reductionArgs;
845 std::optional<ReductionParseArgs> taskReductionArgs;
846 std::optional<MapParseArgs> useDeviceAddrArgs;
847 std::optional<MapParseArgs> useDevicePtrArgs;
852 return "private_barrier";
862 ReductionModifierAttr *modifier =
nullptr,
863 UnitAttr *needsBarrier =
nullptr) {
867 unsigned regionArgOffset = regionPrivateArgs.size();
877 std::optional<ReductionModifier> enumValue =
878 symbolizeReductionModifier(enumStr);
879 if (!enumValue.has_value())
881 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
888 isByRefVec.push_back(
889 parser.parseOptionalKeyword(
"byref").succeeded());
891 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
894 if (parser.parseOperand(operands.emplace_back()) ||
895 parser.parseArrow() ||
896 parser.parseArgument(regionPrivateArgs.emplace_back()))
900 if (parser.parseOptionalLSquare().succeeded()) {
901 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
902 parser.parseInteger(mapIndicesVec.emplace_back()) ||
903 parser.parseRSquare())
906 mapIndicesVec.push_back(-1);
918 if (parser.parseType(types.emplace_back()))
925 if (operands.size() != types.size())
934 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
937 auto *argsBegin = regionPrivateArgs.begin();
939 argsBegin + regionArgOffset + types.size());
940 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
946 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
949 if (!mapIndicesVec.empty())
962 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
977 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
983 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
984 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
985 nullptr, &privateArgs->needsBarrier)))
994 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
999 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1000 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1001 reductionArgs->modifier)))
1008 AllRegionParseArgs args) {
1012 args.hasDeviceAddrArgs)))
1014 <<
"invalid `has_device_addr` format";
1017 args.hostEvalArgs)))
1019 <<
"invalid `host_eval` format";
1022 args.inReductionArgs)))
1024 <<
"invalid `in_reduction` format";
1029 <<
"invalid `map_entries` format";
1034 <<
"invalid `private` format";
1037 args.reductionArgs)))
1039 <<
"invalid `reduction` format";
1042 args.taskReductionArgs)))
1044 <<
"invalid `task_reduction` format";
1047 args.useDeviceAddrArgs)))
1049 <<
"invalid `use_device_addr` format";
1052 args.useDevicePtrArgs)))
1054 <<
"invalid `use_device_addr` format";
1056 return parser.
parseRegion(region, entryBlockArgs);
1075 AllRegionParseArgs args;
1076 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1077 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1078 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1079 inReductionByref, inReductionSyms);
1080 args.mapArgs.emplace(mapVars, mapTypes);
1081 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1082 privateNeedsBarrier, &privateMaps);
1093 UnitAttr &privateNeedsBarrier) {
1094 AllRegionParseArgs args;
1095 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1096 inReductionByref, inReductionSyms);
1097 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1098 privateNeedsBarrier);
1109 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1113 AllRegionParseArgs args;
1114 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1115 inReductionByref, inReductionSyms);
1116 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1117 privateNeedsBarrier);
1118 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1119 reductionSyms, &reductionMod);
1127 UnitAttr &privateNeedsBarrier) {
1128 AllRegionParseArgs args;
1129 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1130 privateNeedsBarrier);
1138 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1142 AllRegionParseArgs args;
1143 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144 privateNeedsBarrier);
1145 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1146 reductionSyms, &reductionMod);
1155 AllRegionParseArgs args;
1156 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1157 taskReductionByref, taskReductionSyms);
1167 AllRegionParseArgs args;
1168 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1169 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1178struct MapPrintArgs {
1183struct PrivatePrintArgs {
1187 UnitAttr needsBarrier;
1191 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1192 mapIndices(mapIndices) {}
1194struct ReductionPrintArgs {
1199 ReductionModifierAttr modifier;
1201 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1202 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1204struct AllRegionPrintArgs {
1205 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1206 std::optional<MapPrintArgs> hostEvalArgs;
1207 std::optional<ReductionPrintArgs> inReductionArgs;
1208 std::optional<MapPrintArgs> mapArgs;
1209 std::optional<PrivatePrintArgs> privateArgs;
1210 std::optional<ReductionPrintArgs> reductionArgs;
1211 std::optional<ReductionPrintArgs> taskReductionArgs;
1212 std::optional<MapPrintArgs> useDeviceAddrArgs;
1213 std::optional<MapPrintArgs> useDevicePtrArgs;
1222 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1223 if (argsSubrange.empty())
1226 p << clauseName <<
"(";
1229 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1233 symbols = ArrayAttr::get(ctx, values);
1246 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1247 mapIndices.asArrayRef(),
1248 byref.asArrayRef()),
1250 auto [op, arg, sym, map, isByRef] = t;
1256 p << op <<
" -> " << arg;
1259 p <<
" [map_idx=" << map <<
"]";
1262 llvm::interleaveComma(types, p);
1270 StringRef clauseName,
ValueRange argsSubrange,
1271 std::optional<MapPrintArgs> mapArgs) {
1278 StringRef clauseName,
ValueRange argsSubrange,
1279 std::optional<PrivatePrintArgs> privateArgs) {
1282 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1283 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1284 nullptr, privateArgs->needsBarrier);
1290 std::optional<ReductionPrintArgs> reductionArgs) {
1293 reductionArgs->vars, reductionArgs->types,
1294 reductionArgs->syms,
nullptr,
1295 reductionArgs->byref, reductionArgs->modifier);
1299 const AllRegionPrintArgs &args) {
1300 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1304 iface.getHasDeviceAddrBlockArgs(),
1305 args.hasDeviceAddrArgs);
1309 args.inReductionArgs);
1315 args.reductionArgs);
1317 iface.getTaskReductionBlockArgs(),
1318 args.taskReductionArgs);
1320 iface.getUseDeviceAddrBlockArgs(),
1321 args.useDeviceAddrArgs);
1323 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1339 AllRegionPrintArgs args;
1340 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1341 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1342 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1343 inReductionByref, inReductionSyms);
1344 args.mapArgs.emplace(mapVars, mapTypes);
1345 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1346 privateNeedsBarrier, privateMaps);
1354 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1355 AllRegionPrintArgs args;
1356 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1357 inReductionByref, inReductionSyms);
1358 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1359 privateNeedsBarrier,
1368 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1369 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1372 AllRegionPrintArgs args;
1373 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1374 inReductionByref, inReductionSyms);
1375 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1376 privateNeedsBarrier,
1378 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1379 reductionSyms, reductionMod);
1386 UnitAttr privateNeedsBarrier) {
1387 AllRegionPrintArgs args;
1388 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1389 privateNeedsBarrier,
1397 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1400 AllRegionPrintArgs args;
1401 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1402 privateNeedsBarrier,
1404 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1405 reductionSyms, reductionMod);
1415 AllRegionPrintArgs args;
1416 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1417 taskReductionByref, taskReductionSyms);
1427 AllRegionPrintArgs args;
1428 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1429 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1438 if (!reductionVars.empty()) {
1439 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1441 <<
"expected as many reduction symbol references "
1442 "as reduction variables";
1443 if (reductionByref && reductionByref->size() != reductionVars.size())
1444 return op->
emitError() <<
"expected as many reduction variable by "
1445 "reference attributes as reduction variables";
1448 return op->
emitOpError() <<
"unexpected reduction symbol references";
1455 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1456 Value accum = std::get<0>(args);
1458 if (!accumulators.insert(accum).second)
1459 return op->
emitOpError() <<
"accumulator variable used more than once";
1462 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1466 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1467 <<
" to point to a reduction declaration";
1469 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1471 <<
"expected accumulator (" << varType
1472 <<
") to be the same type as reduction declaration ("
1473 << decl.getAccumulatorType() <<
")";
1492 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1493 parser.parseArrow() ||
1494 parser.parseAttribute(symsVec.emplace_back()) ||
1495 parser.parseColonType(copyprivateTypes.emplace_back()))
1501 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1509 std::optional<ArrayAttr> copyprivateSyms) {
1510 if (!copyprivateSyms.has_value())
1512 llvm::interleaveComma(
1513 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1514 [&](
const auto &args) {
1515 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1516 << std::get<2>(args);
1523 std::optional<ArrayAttr> copyprivateSyms) {
1524 size_t copyprivateSymsSize =
1525 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1526 if (copyprivateSymsSize != copyprivateVars.size())
1527 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1528 << copyprivateVars.size()
1529 <<
") and functions (= " << copyprivateSymsSize
1530 <<
"), both must be equal";
1531 if (!copyprivateSyms.has_value())
1534 for (
auto copyprivateVarAndSym :
1535 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1537 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1538 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1540 if (mlir::func::FuncOp mlirFuncOp =
1543 funcOp = mlirFuncOp;
1544 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1547 funcOp = llvmFuncOp;
1549 auto getNumArguments = [&] {
1550 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1553 auto getArgumentType = [&](
unsigned i) {
1554 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1559 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1560 <<
" to point to a copy function";
1562 if (getNumArguments() != 2)
1564 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1566 Type argTy = getArgumentType(0);
1567 if (argTy != getArgumentType(1))
1568 return op->
emitOpError() <<
"expected copy function " << symbolRef
1569 <<
" arguments to have the same type";
1571 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1572 if (argTy != varType)
1574 <<
"expected copy function arguments' type (" << argTy
1575 <<
") to be the same as copyprivate variable's type (" << varType
1596 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1597 parser.parseOperand(dependVars.emplace_back()) ||
1598 parser.parseColonType(dependTypes.emplace_back()))
1600 if (std::optional<ClauseTaskDepend> keywordDepend =
1601 (symbolizeClauseTaskDepend(keyword)))
1602 kindsVec.emplace_back(
1603 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1610 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1617 std::optional<ArrayAttr> dependKinds) {
1619 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1622 p << stringifyClauseTaskDepend(
1623 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1625 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1631 std::optional<ArrayAttr> dependKinds,
1633 if (!dependVars.empty()) {
1634 if (!dependKinds || dependKinds->size() != dependVars.size())
1635 return op->
emitOpError() <<
"expected as many depend values"
1636 " as depend variables";
1638 if (dependKinds && !dependKinds->empty())
1639 return op->
emitOpError() <<
"unexpected depend values";
1655 IntegerAttr &hintAttr) {
1656 StringRef hintKeyword;
1662 auto parseKeyword = [&]() -> ParseResult {
1665 if (hintKeyword ==
"uncontended")
1667 else if (hintKeyword ==
"contended")
1669 else if (hintKeyword ==
"nonspeculative")
1671 else if (hintKeyword ==
"speculative")
1675 << hintKeyword <<
" is not a valid hint";
1686 IntegerAttr hintAttr) {
1687 int64_t hint = hintAttr.getInt();
1695 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1697 bool uncontended = bitn(hint, 0);
1698 bool contended = bitn(hint, 1);
1699 bool nonspeculative = bitn(hint, 2);
1700 bool speculative = bitn(hint, 3);
1704 hints.push_back(
"uncontended");
1706 hints.push_back(
"contended");
1708 hints.push_back(
"nonspeculative");
1710 hints.push_back(
"speculative");
1712 llvm::interleaveComma(hints, p);
1719 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1721 bool uncontended = bitn(hint, 0);
1722 bool contended = bitn(hint, 1);
1723 bool nonspeculative = bitn(hint, 2);
1724 bool speculative = bitn(hint, 3);
1726 if (uncontended && contended)
1727 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1728 "omp_sync_hint_contended cannot be combined";
1729 if (nonspeculative && speculative)
1730 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1731 "omp_sync_hint_speculative cannot be combined.";
1742 return (value & flag) == flag;
1750static ParseResult parseMapClause(
OpAsmParser &parser,
1751 ClauseMapFlagsAttr &mapType) {
1752 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1755 auto parseTypeAndMod = [&]() -> ParseResult {
1756 StringRef mapTypeMod;
1760 if (mapTypeMod ==
"always")
1761 mapTypeBits |= ClauseMapFlags::always;
1763 if (mapTypeMod ==
"implicit")
1764 mapTypeBits |= ClauseMapFlags::implicit;
1766 if (mapTypeMod ==
"ompx_hold")
1767 mapTypeBits |= ClauseMapFlags::ompx_hold;
1769 if (mapTypeMod ==
"close")
1770 mapTypeBits |= ClauseMapFlags::close;
1772 if (mapTypeMod ==
"present")
1773 mapTypeBits |= ClauseMapFlags::present;
1775 if (mapTypeMod ==
"to")
1776 mapTypeBits |= ClauseMapFlags::to;
1778 if (mapTypeMod ==
"from")
1779 mapTypeBits |= ClauseMapFlags::from;
1781 if (mapTypeMod ==
"tofrom")
1782 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1784 if (mapTypeMod ==
"delete")
1785 mapTypeBits |= ClauseMapFlags::del;
1787 if (mapTypeMod ==
"storage")
1788 mapTypeBits |= ClauseMapFlags::storage;
1790 if (mapTypeMod ==
"return_param")
1791 mapTypeBits |= ClauseMapFlags::return_param;
1793 if (mapTypeMod ==
"private")
1794 mapTypeBits |= ClauseMapFlags::priv;
1796 if (mapTypeMod ==
"literal")
1797 mapTypeBits |= ClauseMapFlags::literal;
1799 if (mapTypeMod ==
"attach")
1800 mapTypeBits |= ClauseMapFlags::attach;
1802 if (mapTypeMod ==
"attach_always")
1803 mapTypeBits |= ClauseMapFlags::attach_always;
1805 if (mapTypeMod ==
"attach_none")
1806 mapTypeBits |= ClauseMapFlags::attach_none;
1808 if (mapTypeMod ==
"attach_auto")
1809 mapTypeBits |= ClauseMapFlags::attach_auto;
1811 if (mapTypeMod ==
"ref_ptr")
1812 mapTypeBits |= ClauseMapFlags::ref_ptr;
1814 if (mapTypeMod ==
"ref_ptee")
1815 mapTypeBits |= ClauseMapFlags::ref_ptee;
1817 if (mapTypeMod ==
"ref_ptr_ptee")
1818 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1835 ClauseMapFlagsAttr mapType) {
1837 ClauseMapFlags mapFlags = mapType.getValue();
1842 mapTypeStrs.push_back(
"always");
1844 mapTypeStrs.push_back(
"implicit");
1846 mapTypeStrs.push_back(
"ompx_hold");
1848 mapTypeStrs.push_back(
"close");
1850 mapTypeStrs.push_back(
"present");
1859 mapTypeStrs.push_back(
"tofrom");
1861 mapTypeStrs.push_back(
"from");
1863 mapTypeStrs.push_back(
"to");
1866 mapTypeStrs.push_back(
"delete");
1868 mapTypeStrs.push_back(
"return_param");
1870 mapTypeStrs.push_back(
"storage");
1872 mapTypeStrs.push_back(
"private");
1874 mapTypeStrs.push_back(
"literal");
1876 mapTypeStrs.push_back(
"attach");
1878 mapTypeStrs.push_back(
"attach_always");
1880 mapTypeStrs.push_back(
"attach_none");
1882 mapTypeStrs.push_back(
"attach_auto");
1884 mapTypeStrs.push_back(
"ref_ptr");
1886 mapTypeStrs.push_back(
"ref_ptee");
1888 mapTypeStrs.push_back(
"ref_ptr_ptee");
1889 if (mapFlags == ClauseMapFlags::none)
1890 mapTypeStrs.push_back(
"none");
1892 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1893 p << mapTypeStrs[i];
1894 if (i + 1 < mapTypeStrs.size()) {
1900static ParseResult parseMembersIndex(
OpAsmParser &parser,
1904 auto parseIndices = [&]() -> ParseResult {
1909 APInt(64, value,
false)));
1923 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
1927 if (!memberIdxs.empty())
1928 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
1938 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1940 auto memberIdx = cast<ArrayAttr>(v);
1941 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
1942 p << cast<IntegerAttr>(v2).getInt();
1949 VariableCaptureKindAttr mapCaptureType) {
1950 std::string typeCapStr;
1951 llvm::raw_string_ostream typeCap(typeCapStr);
1952 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1954 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1955 typeCap <<
"ByCopy";
1956 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1957 typeCap <<
"VLAType";
1958 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1964 VariableCaptureKindAttr &mapCaptureType) {
1965 StringRef mapCaptureKey;
1969 if (mapCaptureKey ==
"This")
1970 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1971 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
1972 if (mapCaptureKey ==
"ByRef")
1973 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1974 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
1975 if (mapCaptureKey ==
"ByCopy")
1976 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1977 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1978 if (mapCaptureKey ==
"VLAType")
1979 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1980 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
1989 for (
auto mapOp : mapVars) {
1990 if (!mapOp.getDefiningOp())
1993 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1994 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
1997 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2000 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2001 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2002 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2004 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2006 "to, from, tofrom and alloc map types are permitted");
2008 if (isa<TargetEnterDataOp>(op) && (from || del))
2009 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2011 if (isa<TargetExitDataOp>(op) && to)
2013 "from, release and delete map types are permitted");
2015 if (isa<TargetUpdateOp>(op)) {
2018 "at least one of to or from map types must be "
2019 "specified, other map types are not permitted");
2024 "at least one of to or from map types must be "
2025 "specified, other map types are not permitted");
2028 auto updateVar = mapInfoOp.getVarPtr();
2030 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2031 (from && updateToVars.contains(updateVar))) {
2034 "either to or from map types can be specified, not both");
2037 if (always || close || implicit) {
2040 "present, mapper and iterator map type modifiers are permitted");
2043 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2045 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2047 "map argument is not a map entry operation");
2055 std::optional<DenseI64ArrayAttr> privateMapIndices =
2056 targetOp.getPrivateMapsAttr();
2059 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2064 if (privateMapIndices.value().size() !=
2065 static_cast<int64_t>(privateVars.size()))
2066 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2067 "`private_maps` attribute mismatch");
2077 StringRef clauseName,
2079 for (
Value var : vars)
2080 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2082 <<
"'" << clauseName
2083 <<
"' arguments must be defined by 'omp.map.info' ops";
2087LogicalResult MapInfoOp::verify() {
2088 if (getMapperId() &&
2090 *
this, getMapperIdAttr())) {
2105 const TargetDataOperands &clauses) {
2106 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2107 clauses.mapVars, clauses.useDeviceAddrVars,
2108 clauses.useDevicePtrVars);
2111LogicalResult TargetDataOp::verify() {
2112 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2113 getUseDeviceAddrVars().empty()) {
2114 return ::emitError(this->getLoc(),
2115 "At least one of map, use_device_ptr_vars, or "
2116 "use_device_addr_vars operand must be present");
2120 getUseDevicePtrVars())))
2124 getUseDeviceAddrVars())))
2134void TargetEnterDataOp::build(
2138 TargetEnterDataOp::build(builder, state,
2140 clauses.dependVars, clauses.device, clauses.ifExpr,
2141 clauses.mapVars, clauses.nowait);
2144LogicalResult TargetEnterDataOp::verify() {
2145 LogicalResult verifyDependVars =
2147 return failed(verifyDependVars) ? verifyDependVars
2158 TargetExitDataOp::build(builder, state,
2160 clauses.dependVars, clauses.device, clauses.ifExpr,
2161 clauses.mapVars, clauses.nowait);
2164LogicalResult TargetExitDataOp::verify() {
2165 LogicalResult verifyDependVars =
2167 return failed(verifyDependVars) ? verifyDependVars
2178 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2179 clauses.dependVars, clauses.device, clauses.ifExpr,
2180 clauses.mapVars, clauses.nowait);
2183LogicalResult TargetUpdateOp::verify() {
2184 LogicalResult verifyDependVars =
2186 return failed(verifyDependVars) ? verifyDependVars
2195 const TargetOperands &clauses) {
2199 TargetOp::build(builder, state, {}, {},
2201 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2202 clauses.hostEvalVars, clauses.ifExpr,
2204 nullptr, clauses.isDevicePtrVars,
2205 clauses.mapVars, clauses.nowait, clauses.privateVars,
2207 clauses.privateNeedsBarrier, clauses.threadLimit,
2211LogicalResult TargetOp::verify() {
2216 getHasDeviceAddrVars())))
2225LogicalResult TargetOp::verifyRegions() {
2226 auto teamsOps = getOps<TeamsOp>();
2227 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2228 return emitError(
"target containing multiple 'omp.teams' nested ops");
2231 Operation *capturedOp = getInnermostCapturedOmpOp();
2232 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2233 for (
Value hostEvalArg :
2234 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2236 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2237 if (llvm::is_contained({teamsOp.getNumTeamsLower(),
2238 teamsOp.getNumTeamsUpper(),
2239 teamsOp.getThreadLimit()},
2243 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2244 "and 'thread_limit' in 'omp.teams'";
2246 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2247 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2248 parallelOp->isAncestor(capturedOp) &&
2249 hostEvalArg == parallelOp.getNumThreads())
2253 <<
"host_eval argument only legal as 'num_threads' in "
2254 "'omp.parallel' when representing target SPMD";
2256 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2257 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2258 loopNestOp.getOperation() == capturedOp &&
2259 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2260 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2261 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2264 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2265 "and steps in 'omp.loop_nest' when trip count "
2266 "must be evaluated in the host";
2269 return emitOpError() <<
"host_eval argument illegal use in '"
2270 << user->getName() <<
"' operation";
2279 assert(rootOp &&
"expected valid operation");
2296 bool isOmpDialect = op->
getDialect() == ompDialect;
2298 if (!isOmpDialect || !hasRegions)
2305 if (checkSingleMandatoryExec) {
2310 if (successor->isReachable(parentBlock))
2313 for (
Block &block : *parentRegion)
2315 !domInfo.
dominates(parentBlock, &block))
2322 if (&sibling != op && !siblingAllowedFn(&sibling))
2335Operation *TargetOp::getInnermostCapturedOmpOp() {
2336 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2348 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2351 memOp.getEffects(effects);
2352 return !llvm::any_of(
2354 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2355 isa<SideEffects::AutomaticAllocationScopeResource>(
2365 WsloopOp *wsLoopOp) {
2367 if (teamsOp.getNumTeamsUpper())
2371 if (teamsOp.getNumReductionVars())
2373 if (wsLoopOp->getNumReductionVars())
2377 OffloadModuleInterface offloadMod =
2381 auto ompFlags = offloadMod.getFlags();
2384 return ompFlags.getAssumeTeamsOversubscription() &&
2385 ompFlags.getAssumeThreadsOversubscription();
2388TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2393 assert((!capturedOp ||
2394 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2395 "unexpected captured op");
2398 if (!isa_and_present<LoopNestOp>(capturedOp))
2399 return TargetRegionFlags::generic;
2403 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2404 assert(!loopWrappers.empty());
2406 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2407 if (isa<SimdOp>(innermostWrapper))
2408 innermostWrapper = std::next(innermostWrapper);
2410 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2411 if (numWrappers != 1 && numWrappers != 2)
2412 return TargetRegionFlags::generic;
2415 if (numWrappers == 2) {
2416 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2418 return TargetRegionFlags::generic;
2420 innermostWrapper = std::next(innermostWrapper);
2421 if (!isa<DistributeOp>(innermostWrapper))
2422 return TargetRegionFlags::generic;
2425 if (!isa_and_present<ParallelOp>(parallelOp))
2426 return TargetRegionFlags::generic;
2428 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2430 return TargetRegionFlags::generic;
2432 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2433 TargetRegionFlags
result =
2434 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2441 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2443 if (!isa_and_present<TeamsOp>(teamsOp))
2444 return TargetRegionFlags::generic;
2446 if (teamsOp->
getParentOp() != targetOp.getOperation())
2447 return TargetRegionFlags::generic;
2449 if (isa<LoopOp>(innermostWrapper))
2450 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2460 Dialect *ompDialect = targetOp->getDialect();
2464 return sibling && (ompDialect != sibling->
getDialect() ||
2468 TargetRegionFlags
result =
2469 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2474 while (nestedCapture->
getParentOp() != capturedOp)
2477 return isa<ParallelOp>(nestedCapture) ?
result | TargetRegionFlags::spmd
2481 else if (isa<WsloopOp>(innermostWrapper)) {
2483 if (!isa_and_present<ParallelOp>(parallelOp))
2484 return TargetRegionFlags::generic;
2486 if (parallelOp->
getParentOp() == targetOp.getOperation())
2487 return TargetRegionFlags::spmd;
2490 return TargetRegionFlags::generic;
2499 ParallelOp::build(builder, state,
ValueRange(),
2510 const ParallelOperands &clauses) {
2512 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2513 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2515 clauses.privateNeedsBarrier, clauses.procBindKind,
2516 clauses.reductionMod, clauses.reductionVars,
2521template <
typename OpType>
2523 auto privateVars = op.getPrivateVars();
2524 auto privateSyms = op.getPrivateSymsAttr();
2526 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2529 auto numPrivateVars = privateVars.size();
2530 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2532 if (numPrivateVars != numPrivateSyms)
2533 return op.emitError() <<
"inconsistent number of private variables and "
2534 "privatizer op symbols, private vars: "
2536 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2538 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2539 Type varType = std::get<0>(privateVarInfo).getType();
2540 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2541 PrivateClauseOp privatizerOp =
2544 if (privatizerOp ==
nullptr)
2545 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2546 << privateSym <<
"'";
2548 Type privatizerType = privatizerOp.getArgType();
2550 if (privatizerType && (varType != privatizerType))
2551 return op.emitError()
2552 <<
"type mismatch between a "
2553 << (privatizerOp.getDataSharingType() ==
2554 DataSharingClauseType::Private
2557 <<
" variable and its privatizer op, var type: " << varType
2558 <<
" vs. privatizer op type: " << privatizerType;
2564LogicalResult ParallelOp::verify() {
2565 if (getAllocateVars().size() != getAllocatorVars().size())
2567 "expected equal sizes for allocate and allocator variables");
2573 getReductionByref());
2576LogicalResult ParallelOp::verifyRegions() {
2577 auto distChildOps = getOps<DistributeOp>();
2578 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2579 if (numDistChildOps > 1)
2581 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2583 if (numDistChildOps == 1) {
2586 <<
"'omp.composite' attribute missing from composite operation";
2588 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2589 Operation &distributeOp = **distChildOps.begin();
2591 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2595 return emitError() <<
"unexpected OpenMP operation inside of composite "
2597 << childOp.getName();
2599 }
else if (isComposite()) {
2601 <<
"'omp.composite' attribute present in non-composite operation";
2618 const TeamsOperands &clauses) {
2621 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2622 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2624 nullptr, clauses.reductionMod,
2625 clauses.reductionVars,
2628 clauses.threadLimit);
2631LogicalResult TeamsOp::verify() {
2640 return emitError(
"expected to be nested inside of omp.target or not nested "
2641 "in any OpenMP dialect operations");
2644 if (
auto numTeamsLowerBound = getNumTeamsLower()) {
2645 auto numTeamsUpperBound = getNumTeamsUpper();
2646 if (!numTeamsUpperBound)
2647 return emitError(
"expected num_teams upper bound to be defined if the "
2648 "lower bound is defined");
2649 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2651 "expected num_teams upper bound and lower bound to be the same type");
2655 if (getAllocateVars().size() != getAllocatorVars().size())
2657 "expected equal sizes for allocate and allocator variables");
2660 getReductionByref());
2668 return getParentOp().getPrivateVars();
2672 return getParentOp().getReductionVars();
2680 const SectionsOperands &clauses) {
2683 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2686 clauses.reductionMod, clauses.reductionVars,
2691LogicalResult SectionsOp::verify() {
2692 if (getAllocateVars().size() != getAllocatorVars().size())
2694 "expected equal sizes for allocate and allocator variables");
2697 getReductionByref());
2700LogicalResult SectionsOp::verifyRegions() {
2701 for (
auto &inst : *getRegion().begin()) {
2702 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2704 <<
"expected omp.section op or terminator op inside region";
2716 const SingleOperands &clauses) {
2719 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2720 clauses.copyprivateVars,
2721 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2726LogicalResult SingleOp::verify() {
2728 if (getAllocateVars().size() != getAllocatorVars().size())
2730 "expected equal sizes for allocate and allocator variables");
2733 getCopyprivateSyms());
2741 const WorkshareOperands &clauses) {
2742 WorkshareOp::build(builder, state, clauses.nowait);
2749LogicalResult WorkshareLoopWrapperOp::verify() {
2750 if (!(*this)->getParentOfType<WorkshareOp>())
2751 return emitOpError() <<
"must be nested in an omp.workshare";
2755LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2756 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2758 return emitOpError() <<
"expected to be a standalone loop wrapper";
2767LogicalResult LoopWrapperInterface::verifyImpl() {
2771 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2772 "and `SingleBlock` traits";
2775 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2778 if (range_size(region.
getOps()) != 1)
2780 <<
"loop wrapper does not contain exactly one nested op";
2783 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2784 return emitOpError() <<
"nested in loop wrapper is not another loop "
2785 "wrapper or `omp.loop_nest`";
2795 const LoopOperands &clauses) {
2798 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2800 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2801 clauses.reductionMod, clauses.reductionVars,
2806LogicalResult LoopOp::verify() {
2808 getReductionByref());
2811LogicalResult LoopOp::verifyRegions() {
2812 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2814 return emitOpError() <<
"expected to be a standalone loop wrapper";
2825 build(builder, state, {}, {},
2827 false,
nullptr,
nullptr,
2828 nullptr, {},
nullptr,
2839 const WsloopOperands &clauses) {
2844 {}, {}, clauses.linearVars,
2845 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2846 clauses.ordered, clauses.privateVars,
2847 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2848 clauses.reductionMod, clauses.reductionVars,
2850 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2851 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2854LogicalResult WsloopOp::verify() {
2856 getReductionByref());
2859LogicalResult WsloopOp::verifyRegions() {
2860 bool isCompositeChildLeaf =
2861 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2863 if (LoopWrapperInterface nested = getNestedWrapper()) {
2866 <<
"'omp.composite' attribute missing from composite wrapper";
2870 if (!isa<SimdOp>(nested))
2871 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2873 }
else if (isComposite() && !isCompositeChildLeaf) {
2875 <<
"'omp.composite' attribute present in non-composite wrapper";
2876 }
else if (!isComposite() && isCompositeChildLeaf) {
2878 <<
"'omp.composite' attribute missing from composite wrapper";
2889 const SimdOperands &clauses) {
2892 SimdOp::build(builder, state, clauses.alignedVars,
2895 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2896 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2897 clauses.privateNeedsBarrier, clauses.reductionMod,
2898 clauses.reductionVars,
2904LogicalResult SimdOp::verify() {
2905 if (getSimdlen().has_value() && getSafelen().has_value() &&
2906 getSimdlen().value() > getSafelen().value())
2908 <<
"simdlen clause and safelen clause are both present, but the "
2909 "simdlen value is not less than or equal to safelen value";
2917 bool isCompositeChildLeaf =
2918 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2920 if (!isComposite() && isCompositeChildLeaf)
2922 <<
"'omp.composite' attribute missing from composite wrapper";
2924 if (isComposite() && !isCompositeChildLeaf)
2926 <<
"'omp.composite' attribute present in non-composite wrapper";
2930 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2932 for (
const Attribute &sym : *privateSyms) {
2933 auto symRef = cast<SymbolRefAttr>(sym);
2934 omp::PrivateClauseOp privatizer =
2936 getOperation(), symRef);
2938 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
2939 if (privatizer.getDataSharingType() ==
2940 DataSharingClauseType::FirstPrivate)
2941 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
2948LogicalResult SimdOp::verifyRegions() {
2949 if (getNestedWrapper())
2950 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
2960 const DistributeOperands &clauses) {
2961 DistributeOp::build(builder, state, clauses.allocateVars,
2962 clauses.allocatorVars, clauses.distScheduleStatic,
2963 clauses.distScheduleChunkSize, clauses.order,
2964 clauses.orderMod, clauses.privateVars,
2966 clauses.privateNeedsBarrier);
2969LogicalResult DistributeOp::verify() {
2970 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2972 "dist_schedule_static being present";
2974 if (getAllocateVars().size() != getAllocatorVars().size())
2976 "expected equal sizes for allocate and allocator variables");
2981LogicalResult DistributeOp::verifyRegions() {
2982 if (LoopWrapperInterface nested = getNestedWrapper()) {
2985 <<
"'omp.composite' attribute missing from composite wrapper";
2988 if (isa<WsloopOp>(nested)) {
2990 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2991 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2992 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
2993 "when a composite 'omp.parallel' is the direct "
2996 }
else if (!isa<SimdOp>(nested))
2997 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
2999 }
else if (isComposite()) {
3001 <<
"'omp.composite' attribute present in non-composite wrapper";
3011LogicalResult DeclareMapperInfoOp::verify() {
3015LogicalResult DeclareMapperOp::verifyRegions() {
3016 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3017 getRegion().getBlocks().front().getTerminator()))
3018 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3027LogicalResult DeclareReductionOp::verifyRegions() {
3028 if (!getAllocRegion().empty()) {
3029 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3030 if (yieldOp.getResults().size() != 1 ||
3031 yieldOp.getResults().getTypes()[0] !=
getType())
3032 return emitOpError() <<
"expects alloc region to yield a value "
3033 "of the reduction type";
3037 if (getInitializerRegion().empty())
3038 return emitOpError() <<
"expects non-empty initializer region";
3039 Block &initializerEntryBlock = getInitializerRegion().
front();
3042 if (!getAllocRegion().empty())
3043 return emitOpError() <<
"expects two arguments to the initializer region "
3044 "when an allocation region is used";
3046 if (getAllocRegion().empty())
3047 return emitOpError() <<
"expects one argument to the initializer region "
3048 "when no allocation region is used";
3051 <<
"expects one or two arguments to the initializer region";
3055 if (arg.getType() !=
getType())
3056 return emitOpError() <<
"expects initializer region argument to match "
3057 "the reduction type";
3059 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3060 if (yieldOp.getResults().size() != 1 ||
3061 yieldOp.getResults().getTypes()[0] !=
getType())
3062 return emitOpError() <<
"expects initializer region to yield a value "
3063 "of the reduction type";
3066 if (getReductionRegion().empty())
3067 return emitOpError() <<
"expects non-empty reduction region";
3068 Block &reductionEntryBlock = getReductionRegion().
front();
3073 return emitOpError() <<
"expects reduction region with two arguments of "
3074 "the reduction type";
3075 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3076 if (yieldOp.getResults().size() != 1 ||
3077 yieldOp.getResults().getTypes()[0] !=
getType())
3078 return emitOpError() <<
"expects reduction region to yield a value "
3079 "of the reduction type";
3082 if (!getAtomicReductionRegion().empty()) {
3083 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3087 return emitOpError() <<
"expects atomic reduction region with two "
3088 "arguments of the same type";
3089 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3092 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3093 return emitOpError() <<
"expects atomic reduction region arguments to "
3094 "be accumulators containing the reduction type";
3097 if (getCleanupRegion().empty())
3099 Block &cleanupEntryBlock = getCleanupRegion().
front();
3102 return emitOpError() <<
"expects cleanup region with one argument "
3103 "of the reduction type";
3113 const TaskOperands &clauses) {
3115 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3116 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3117 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3119 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3120 clauses.priority, clauses.privateVars,
3122 clauses.privateNeedsBarrier, clauses.untied,
3123 clauses.eventHandle);
3126LogicalResult TaskOp::verify() {
3127 LogicalResult verifyDependVars =
3129 return failed(verifyDependVars)
3132 getInReductionVars(),
3133 getInReductionByref());
3141 const TaskgroupOperands &clauses) {
3143 TaskgroupOp::build(builder, state, clauses.allocateVars,
3144 clauses.allocatorVars, clauses.taskReductionVars,
3149LogicalResult TaskgroupOp::verify() {
3151 getTaskReductionVars(),
3152 getTaskReductionByref());
3160 const TaskloopOperands &clauses) {
3163 builder, state, clauses.allocateVars, clauses.allocatorVars,
3164 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3165 clauses.inReductionVars,
3167 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3168 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3169 clauses.privateVars,
3171 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3176LogicalResult TaskloopOp::verify() {
3177 if (getAllocateVars().size() != getAllocatorVars().size())
3179 "expected equal sizes for allocate and allocator variables");
3181 getReductionVars(), getReductionByref())) ||
3183 getInReductionVars(),
3184 getInReductionByref())))
3187 if (!getReductionVars().empty() && getNogroup())
3188 return emitError(
"if a reduction clause is present on the taskloop "
3189 "directive, the nogroup clause must not be specified");
3190 for (
auto var : getReductionVars()) {
3191 if (llvm::is_contained(getInReductionVars(), var))
3192 return emitError(
"the same list item cannot appear in both a reduction "
3193 "and an in_reduction clause");
3196 if (getGrainsize() && getNumTasks()) {
3198 "the grainsize clause and num_tasks clause are mutually exclusive and "
3199 "may not appear on the same taskloop directive");
3205LogicalResult TaskloopOp::verifyRegions() {
3206 if (LoopWrapperInterface nested = getNestedWrapper()) {
3209 <<
"'omp.composite' attribute missing from composite wrapper";
3213 if (!isa<SimdOp>(nested))
3214 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3215 }
else if (isComposite()) {
3217 <<
"'omp.composite' attribute present in non-composite wrapper";
3241 for (
auto &iv : ivs)
3242 iv.type = loopVarType;
3247 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3263 "collapse_num_loops",
3268 auto parseTiles = [&]() -> ParseResult {
3272 tiles.push_back(
tile);
3281 if (tiles.size() > 0)
3300 Region ®ion = getRegion();
3302 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3303 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3304 if (getLoopInclusive())
3306 p <<
"step (" << getLoopSteps() <<
") ";
3307 if (
int64_t numCollapse = getCollapseNumLoops())
3308 if (numCollapse > 1)
3309 p <<
"collapse(" << numCollapse <<
") ";
3312 p <<
"tiles(" << tiles.value() <<
") ";
3318 const LoopNestOperands &clauses) {
3320 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3321 clauses.loopLowerBounds, clauses.loopUpperBounds,
3322 clauses.loopSteps, clauses.loopInclusive,
3326LogicalResult LoopNestOp::verify() {
3327 if (getLoopLowerBounds().empty())
3328 return emitOpError() <<
"must represent at least one loop";
3330 if (getLoopLowerBounds().size() != getIVs().size())
3331 return emitOpError() <<
"number of range arguments and IVs do not match";
3333 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3334 if (lb.getType() != iv.getType())
3336 <<
"range argument type does not match corresponding IV type";
3339 uint64_t numIVs = getIVs().size();
3341 if (
const auto &numCollapse = getCollapseNumLoops())
3342 if (numCollapse > numIVs)
3344 <<
"collapse value is larger than the number of loops";
3347 if (tiles.value().size() > numIVs)
3348 return emitOpError() <<
"too few canonical loops for tile dimensions";
3350 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3351 return emitOpError() <<
"expects parent op to be a loop wrapper";
3356void LoopNestOp::gatherWrappers(
3359 while (
auto wrapper =
3360 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3361 wrappers.push_back(wrapper);
3370std::tuple<NewCliOp, OpOperand *, OpOperand *>
3376 return {{},
nullptr,
nullptr};
3379 "Unexpected type of cli");
3385 auto op = cast<LoopTransformationInterface>(use.getOwner());
3387 unsigned opnum = use.getOperandNumber();
3388 if (op.isGeneratee(opnum)) {
3389 assert(!gen &&
"Each CLI may have at most one def");
3391 }
else if (op.isApplyee(opnum)) {
3392 assert(!cons &&
"Each CLI may have at most one consumer");
3395 llvm_unreachable(
"Unexpected operand for a CLI");
3399 return {create, gen, cons};
3422 std::string cliName{
"cli"};
3426 .Case([&](CanonicalLoopOp op) {
3429 .Case([&](UnrollHeuristicOp op) -> std::string {
3430 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3432 .Case([&](TileOp op) -> std::string {
3433 auto [generateesFirst, generateesCount] =
3434 op.getGenerateesODSOperandIndexAndLength();
3435 unsigned firstGrid = generateesFirst;
3436 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3437 unsigned end = generateesFirst + generateesCount;
3438 unsigned opnum =
generator->getOperandNumber();
3440 if (firstGrid <= opnum && opnum < firstIntratile) {
3441 unsigned gridnum = opnum - firstGrid + 1;
3442 return (
"grid" + Twine(gridnum)).str();
3444 if (firstIntratile <= opnum && opnum < end) {
3445 unsigned intratilenum = opnum - firstIntratile + 1;
3446 return (
"intratile" + Twine(intratilenum)).str();
3448 llvm_unreachable(
"Unexpected generatee argument");
3450 .DefaultUnreachable(
"TODO: Custom name for this operation");
3453 setNameFn(
result, cliName);
3456LogicalResult NewCliOp::verify() {
3457 Value cli = getResult();
3460 "Unexpected type of cli");
3466 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3468 unsigned opnum = use.getOperandNumber();
3469 if (op.isGeneratee(opnum)) {
3472 emitOpError(
"CLI must have at most one generator");
3474 .
append(
"first generator here:");
3476 .
append(
"second generator here:");
3481 }
else if (op.isApplyee(opnum)) {
3484 emitOpError(
"CLI must have at most one consumer");
3486 .
append(
"first consumer here:")
3490 .
append(
"second consumer here:")
3497 llvm_unreachable(
"Unexpected operand for a CLI");
3505 .
append(
"see consumer here: ")
3528 setNameFn(&getRegion().front(),
"body_entry");
3531void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3539 p <<
'(' << getCli() <<
')';
3540 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3541 <<
" in range(" << getTripCount() <<
") ";
3551 CanonicalLoopInfoType cliType =
3552 CanonicalLoopInfoType::get(parser.
getContext());
3577 if (parser.
parseRegion(*region, {inductionVariable}))
3582 result.operands.append(cliOperand);
3588 return mlir::success();
3591LogicalResult CanonicalLoopOp::verify() {
3594 if (!getRegion().empty()) {
3595 Region ®ion = getRegion();
3598 "Canonical loop region must have exactly one argument");
3602 "Region argument must be the same type as the trip count");
3608Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3610std::pair<unsigned, unsigned>
3611CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3616std::pair<unsigned, unsigned>
3617CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3618 return getODSOperandIndexAndLength(odsIndex_cli);
3632 p <<
'(' << getApplyee() <<
')';
3639 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
3662 return mlir::success();
3665std::pair<unsigned, unsigned>
3666UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3667 return getODSOperandIndexAndLength(odsIndex_applyee);
3670std::pair<unsigned, unsigned>
3671UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3682 if (!generatees.empty())
3683 p <<
'(' << llvm::interleaved(generatees) <<
')';
3685 if (!applyees.empty())
3686 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3717LogicalResult TileOp::verify() {
3718 if (getApplyees().empty())
3719 return emitOpError() <<
"must apply to at least one loop";
3721 if (getSizes().size() != getApplyees().size())
3722 return emitOpError() <<
"there must be one tile size for each applyee";
3724 if (!getGeneratees().empty() &&
3725 2 * getSizes().size() != getGeneratees().size())
3727 <<
"expecting two times the number of generatees than applyees";
3731 Value parent = getApplyees().front();
3732 for (
auto &&applyee : llvm::drop_begin(getApplyees())) {
3733 auto [parentCreate, parentGen, parentCons] =
decodeCli(parent);
3734 auto [create, gen, cons] =
decodeCli(applyee);
3737 return emitOpError() <<
"applyee CLI has no generator";
3739 auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3742 <<
"currently only supports omp.canonical_loop as applyee";
3744 parentIVs.insert(parentLoop.getInductionVar());
3747 return emitOpError() <<
"applyee CLI has no generator";
3748 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3751 <<
"currently only supports omp.canonical_loop as applyee";
3757 auto &parentBody = parentLoop.getRegion();
3758 if (!parentBody.hasOneBlock())
3760 auto &parentBlock = parentBody.getBlocks().front();
3762 auto nestedLoopIt = parentBlock.begin();
3763 if (nestedLoopIt == parentBlock.end() ||
3764 (&*nestedLoopIt != loop.getOperation()))
3767 auto termIt = std::next(nestedLoopIt);
3768 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3771 if (std::next(termIt) != parentBlock.end())
3776 if (!isPerfectlyNested)
3777 return emitOpError() <<
"tiled loop nest must be perfectly nested";
3779 if (parentIVs.contains(loop.getTripCount()))
3780 return emitOpError() <<
"tiled loop nest must be rectangular";
3799std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3800 return getODSOperandIndexAndLength(odsIndex_applyees);
3803std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3804 return getODSOperandIndexAndLength(odsIndex_generatees);
3812 const CriticalDeclareOperands &clauses) {
3813 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3816LogicalResult CriticalDeclareOp::verify() {
3821 if (getNameAttr()) {
3822 SymbolRefAttr symbolRef = getNameAttr();
3826 return emitOpError() <<
"expected symbol reference " << symbolRef
3827 <<
" to point to a critical declaration";
3847 return op.
emitOpError() <<
"must be nested inside of a loop";
3851 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3852 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3854 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
3855 "have an ordered clause";
3857 if (hasRegion && orderedAttr.getInt() != 0)
3858 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
3859 "have a parameter present";
3861 if (!hasRegion && orderedAttr.getInt() == 0)
3862 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
3863 "have a parameter present";
3864 }
else if (!isa<SimdOp>(wrapper)) {
3865 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
3866 "or worksharing simd loop";
3872 const OrderedOperands &clauses) {
3873 OrderedOp::build(builder, state, clauses.doacrossDependType,
3874 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3877LogicalResult OrderedOp::verify() {
3881 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3882 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3883 return emitOpError() <<
"number of variables in depend clause does not "
3884 <<
"match number of iteration variables in the "
3891 const OrderedRegionOperands &clauses) {
3892 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3902 const TaskwaitOperands &clauses) {
3904 TaskwaitOp::build(builder, state,
nullptr,
3912LogicalResult AtomicReadOp::verify() {
3913 if (verifyCommon().
failed())
3914 return mlir::failure();
3916 if (
auto mo = getMemoryOrder()) {
3917 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3918 *mo == ClauseMemoryOrderKind::Release) {
3920 "memory-order must not be acq_rel or release for atomic reads");
3930LogicalResult AtomicWriteOp::verify() {
3931 if (verifyCommon().
failed())
3932 return mlir::failure();
3934 if (
auto mo = getMemoryOrder()) {
3935 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3936 *mo == ClauseMemoryOrderKind::Acquire) {
3938 "memory-order must not be acq_rel or acquire for atomic writes");
3948LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3954 if (
Value writeVal = op.getWriteOpVal()) {
3956 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3962LogicalResult AtomicUpdateOp::verify() {
3963 if (verifyCommon().
failed())
3964 return mlir::failure();
3966 if (
auto mo = getMemoryOrder()) {
3967 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3968 *mo == ClauseMemoryOrderKind::Acquire) {
3970 "memory-order must not be acq_rel or acquire for atomic updates");
3977LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3983AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3984 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3986 return dyn_cast<AtomicReadOp>(getSecondOp());
3989AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3990 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3992 return dyn_cast<AtomicWriteOp>(getSecondOp());
3995AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3996 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3998 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4001LogicalResult AtomicCaptureOp::verify() {
4005LogicalResult AtomicCaptureOp::verifyRegions() {
4006 if (verifyRegionsCommon().
failed())
4007 return mlir::failure();
4009 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4011 "operations inside capture region must not have hint clause");
4013 if (getFirstOp()->getAttr(
"memory_order") ||
4014 getSecondOp()->getAttr(
"memory_order"))
4016 "operations inside capture region must not have memory_order clause");
4025 const CancelOperands &clauses) {
4026 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4039LogicalResult CancelOp::verify() {
4040 ClauseCancellationConstructType cct = getCancelDirective();
4043 if (!structuralParent)
4044 return emitOpError() <<
"Orphaned cancel construct";
4046 if ((cct == ClauseCancellationConstructType::Parallel) &&
4047 !mlir::isa<ParallelOp>(structuralParent)) {
4048 return emitOpError() <<
"cancel parallel must appear "
4049 <<
"inside a parallel region";
4051 if (cct == ClauseCancellationConstructType::Loop) {
4054 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4058 <<
"cancel loop must appear inside a worksharing-loop region";
4060 if (wsloopOp.getNowaitAttr()) {
4061 return emitError() <<
"A worksharing construct that is canceled "
4062 <<
"must not have a nowait clause";
4064 if (wsloopOp.getOrderedAttr()) {
4065 return emitError() <<
"A worksharing construct that is canceled "
4066 <<
"must not have an ordered clause";
4069 }
else if (cct == ClauseCancellationConstructType::Sections) {
4073 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4075 return emitOpError() <<
"cancel sections must appear "
4076 <<
"inside a sections region";
4078 if (sectionsOp.getNowait()) {
4079 return emitError() <<
"A sections construct that is canceled "
4080 <<
"must not have a nowait clause";
4083 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4084 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4085 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4086 return emitOpError() <<
"cancel taskgroup must appear "
4087 <<
"inside a task region";
4097 const CancellationPointOperands &clauses) {
4098 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4101LogicalResult CancellationPointOp::verify() {
4102 ClauseCancellationConstructType cct = getCancelDirective();
4105 if (!structuralParent)
4106 return emitOpError() <<
"Orphaned cancellation point";
4108 if ((cct == ClauseCancellationConstructType::Parallel) &&
4109 !mlir::isa<ParallelOp>(structuralParent)) {
4110 return emitOpError() <<
"cancellation point parallel must appear "
4111 <<
"inside a parallel region";
4115 if ((cct == ClauseCancellationConstructType::Loop) &&
4116 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4117 return emitOpError() <<
"cancellation point loop must appear "
4118 <<
"inside a worksharing-loop region";
4120 if ((cct == ClauseCancellationConstructType::Sections) &&
4121 !mlir::isa<omp::SectionOp>(structuralParent)) {
4122 return emitOpError() <<
"cancellation point sections must appear "
4123 <<
"inside a sections region";
4125 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4126 !mlir::isa<omp::TaskOp>(structuralParent)) {
4127 return emitOpError() <<
"cancellation point taskgroup must appear "
4128 <<
"inside a task region";
4137LogicalResult MapBoundsOp::verify() {
4138 auto extent = getExtent();
4140 if (!extent && !upperbound)
4141 return emitError(
"expected extent or upperbound.");
4148 PrivateClauseOp::build(
4149 odsBuilder, odsState, symName, type,
4150 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4151 DataSharingClauseType::Private));
4154LogicalResult PrivateClauseOp::verifyRegions() {
4155 Type argType = getArgType();
4156 auto verifyTerminator = [&](
Operation *terminator,
4157 bool yieldsValue) -> LogicalResult {
4161 if (!llvm::isa<YieldOp>(terminator))
4163 <<
"expected exit block terminator to be an `omp.yield` op.";
4165 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4166 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4169 if (yieldedTypes.empty())
4173 <<
"Did not expect any values to be yielded.";
4176 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4180 <<
"Invalid yielded value. Expected type: " << argType
4183 if (yieldedTypes.empty())
4186 error << yieldedTypes;
4192 StringRef regionName,
4193 bool yieldsValue) -> LogicalResult {
4194 assert(!region.
empty());
4198 <<
"`" << regionName <<
"`: "
4199 <<
"expected " << expectedNumArgs
4202 for (
Block &block : region) {
4204 if (!block.mightHaveTerminator())
4207 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4215 for (
Region *region : getRegions())
4216 for (
Type ty : region->getArgumentTypes())
4218 return emitError() <<
"Region argument type mismatch: got " << ty
4219 <<
" expected " << argType <<
".";
4222 if (!initRegion.
empty() &&
4227 DataSharingClauseType dsType = getDataSharingType();
4229 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4230 return emitError(
"`private` clauses do not require a `copy` region.");
4232 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4234 "`firstprivate` clauses require at least a `copy` region.");
4236 if (dsType == DataSharingClauseType::FirstPrivate &&
4241 if (!getDeallocRegion().empty() &&
4254 const MaskedOperands &clauses) {
4255 MaskedOp::build(builder, state, clauses.filteredThreadId);
4263 const ScanOperands &clauses) {
4264 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4267LogicalResult ScanOp::verify() {
4268 if (hasExclusiveVars() == hasInclusiveVars())
4270 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4271 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4272 if (parentWsLoopOp.getReductionModAttr() &&
4273 parentWsLoopOp.getReductionModAttr().getValue() ==
4274 ReductionModifier::inscan)
4277 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4278 if (parentSimdOp.getReductionModAttr() &&
4279 parentSimdOp.getReductionModAttr().getValue() ==
4280 ReductionModifier::inscan)
4283 return emitError(
"SCAN directive needs to be enclosed within a parent "
4284 "worksharing loop construct or SIMD construct with INSCAN "
4285 "reduction modifier");
4290LogicalResult AllocateDirOp::verify() {
4291 std::optional<uint64_t> align = this->getAlign();
4293 if (align.has_value()) {
4294 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4295 return emitError() <<
"ALIGN value : " << align.value()
4296 <<
" must be power of 2";
4306mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4307 return getInTypeAttr().getValue();
4316 bool hasOperands =
false;
4317 std::int32_t typeparamsSize = 0;
4323 return mlir::failure();
4325 return mlir::failure();
4327 return mlir::failure();
4331 return mlir::failure();
4332 result.addAttribute(
"in_type", mlir::TypeAttr::get(intype));
4339 return mlir::failure();
4340 typeparamsSize = operands.size();
4343 std::int32_t shapeSize = 0;
4347 return mlir::failure();
4348 shapeSize = operands.size() - typeparamsSize;
4350 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4351 typeVec.push_back(idxTy);
4357 return mlir::failure();
4362 return mlir::failure();
4365 result.addAttribute(
"operandSegmentSizes",
4369 return mlir::failure();
4370 return mlir::success();
4385 if (!getTypeparams().empty()) {
4386 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4393 {
"in_type",
"operandSegmentSizes"});
4396llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4398 if (!mlir::dyn_cast<IntegerType>(outType))
4400 return mlir::success();
4407LogicalResult WorkdistributeOp::verify() {
4409 Region ®ion = getRegion();
4414 if (entryBlock.
empty())
4415 return emitOpError(
"region must contain a structured block");
4417 bool hasTerminator =
false;
4418 for (
Block &block : region) {
4419 if (isa<TerminatorOp>(block.back())) {
4420 if (hasTerminator) {
4421 return emitOpError(
"region must have exactly one terminator");
4423 hasTerminator =
true;
4426 if (!hasTerminator) {
4427 return emitOpError(
"region must be terminated with omp.terminator");
4431 if (isa<BarrierOp>(op)) {
4433 "explicit barriers are not allowed in workdistribute region");
4436 if (isa<ParallelOp>(op)) {
4438 "nested parallel constructs not allowed in workdistribute");
4440 if (isa<TeamsOp>(op)) {
4442 "nested teams constructs not allowed in workdistribute");
4446 if (walkResult.wasInterrupted())
4450 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4451 return emitOpError(
"workdistribute must be nested under teams");
4455#define GET_ATTRDEF_CLASSES
4456#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4458#define GET_OP_CLASSES
4459#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4461#define GET_TYPEDEF_CLASSES
4462#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
*attr dict without keyword *static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ©privateVars, SmallVectorImpl< Type > ©privateTypes, ArrayAttr ©privateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
bool isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling for imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.