26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/PostOrderIterator.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/STLForwardCompat.h"
30#include "llvm/ADT/SmallString.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/ADT/bit.h"
35#include "llvm/Support/InterleavedRange.h"
41#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
51 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
65struct MemRefPointerLikeModel
66 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 return llvm::cast<MemRefType>(pointer).getElementType();
73struct LLVMPointerPointerLikeModel
74 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
100 bool isRegionArgOfOp;
110 assert(isRegionArgOfOp &&
"Must describe a region operand");
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp &&
"Must describe a region operand");
119 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
123 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
126 bool isLoopOp()
const {
127 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
139 void skipIf(
bool v =
true) { skip = skip || v; }
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp =
true;
162 for (
Block *
b : traversal) {
164 if (&op == o && !found) {
168 if (op.getNumRegions()) {
171 isOnlyContainerOp =
false;
173 if (found && !isOnlyContainerOp)
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp =
false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
188 Component ®ionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp =
true;
190 regionArgOfOperation.isUnique =
true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
204 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
208 llvm_unreachable(
"Region not child of its parent operation");
210 regionArgOfOperation.isUnique =
false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
237 numSurroundingLoops = 0;
239 c.getLoopDepth() = numSurroundingLoops;
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
248 bool isLoopNest =
false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 }
else if (isLoopNest) {
258 c.skipIf(c.isUnique);
262 if (c.getLoopDepth() == 0)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
275 bool newRegion =
true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
284 if (!c.isRegionArgOfOp && c.getContainerOp())
290 llvm::raw_svector_ostream NameOS(Name);
291 for (
auto &c : llvm::reverse(components)) {
295 if (c.isRegionArgOfOp)
296 NameOS <<
"_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS <<
"_d" << c.getLoopDepth();
300 NameOS <<
"_s" << c.getOpPos();
303 return NameOS.str().str();
306void OpenMPDialect::initialize() {
309#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312#define GET_ATTRDEF_LIST
313#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316#define GET_TYPEDEF_LIST
317#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
334 mlir::LLVM::GlobalOp::attachInterface<
337 mlir::LLVM::LLVMFuncOp::attachInterface<
340 mlir::func::FuncOp::attachInterface<
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
385 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
387 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
388 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
396template <
typename ClauseAttr>
398 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
407 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
410template <
typename ClauseAttr>
412 p << stringifyEnum(attr.getValue());
437 linearVars.push_back(var);
438 linearTypes.push_back(type);
439 linearStepVars.push_back(stepVar);
440 linearStepTypes.push_back(stepType);
450 size_t linearVarsSize = linearVars.size();
451 for (
unsigned i = 0; i < linearVarsSize; ++i) {
452 std::string separator = i == linearVarsSize - 1 ?
"" :
", ";
453 p << linearVars[i] <<
" : " << linearTypes[i];
454 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
468 for (
const auto &it : nontemporalVars)
469 if (!nontemporalItems.insert(it).second)
470 return op->
emitOpError() <<
"nontemporal variable used more than once";
479 std::optional<ArrayAttr> alignments,
482 if (!alignedVars.empty()) {
483 if (!alignments || alignments->size() != alignedVars.size())
485 <<
"expected as many alignment values as aligned variables";
488 return op->
emitOpError() <<
"unexpected alignment values attribute";
494 for (
auto it : alignedVars)
495 if (!alignedItems.insert(it).second)
496 return op->
emitOpError() <<
"aligned variable used more than once";
502 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
503 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
504 if (intAttr.getValue().sle(0))
505 return op->
emitOpError() <<
"alignment should be greater than 0";
507 return op->
emitOpError() <<
"expected integer alignment";
524 if (parser.parseOperand(alignedVars.emplace_back()) ||
525 parser.parseColonType(alignedTypes.emplace_back()) ||
526 parser.parseArrow() ||
527 parser.parseAttribute(alignmentVec.emplace_back())) {
534 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
541 std::optional<ArrayAttr> alignments) {
542 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
545 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
546 p <<
" -> " << (*alignments)[i];
557 if (modifiers.size() > 2)
559 for (
const auto &mod : modifiers) {
562 auto symbol = symbolizeScheduleModifier(mod);
565 <<
" unknown modifier type: " << mod;
570 if (modifiers.size() == 1) {
571 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
572 modifiers.push_back(modifiers[0]);
573 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
575 }
else if (modifiers.size() == 2) {
578 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
579 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
581 <<
" incorrect modifier order";
597 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
598 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
603 std::optional<mlir::omp::ClauseScheduleKind> schedule =
604 symbolizeClauseScheduleKind(keyword);
608 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
610 case ClauseScheduleKind::Static:
611 case ClauseScheduleKind::Dynamic:
612 case ClauseScheduleKind::Guided:
618 chunkSize = std::nullopt;
621 case ClauseScheduleKind::Auto:
622 case ClauseScheduleKind::Runtime:
623 case ClauseScheduleKind::Distribute:
624 chunkSize = std::nullopt;
633 modifiers.push_back(mod);
639 if (!modifiers.empty()) {
641 if (std::optional<ScheduleModifier> mod =
642 symbolizeScheduleModifier(modifiers[0])) {
643 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
645 return parser.
emitError(loc,
"invalid schedule modifier");
648 if (modifiers.size() > 1) {
649 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
659 ClauseScheduleKindAttr scheduleKind,
660 ScheduleModifierAttr scheduleMod,
661 UnitAttr scheduleSimd,
Value scheduleChunk,
662 Type scheduleChunkType) {
663 p << stringifyClauseScheduleKind(scheduleKind.getValue());
665 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
667 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
679 ClauseOrderKindAttr &order,
680 OrderModifierAttr &orderMod) {
685 if (std::optional<OrderModifier> enumValue =
686 symbolizeOrderModifier(enumStr)) {
687 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
694 if (std::optional<ClauseOrderKind> enumValue =
695 symbolizeClauseOrderKind(enumStr)) {
696 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
699 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
703 ClauseOrderKindAttr order,
704 OrderModifierAttr orderMod) {
706 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
708 p << stringifyClauseOrderKind(order.getValue());
711template <
typename ClauseTypeAttr,
typename ClauseType>
714 std::optional<OpAsmParser::UnresolvedOperand> &operand,
716 std::optional<ClauseType> (*symbolizeClause)(StringRef),
717 StringRef clauseName) {
720 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
721 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
726 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
736 <<
"expected " << clauseName <<
" operand";
739 if (operand.has_value()) {
747template <
typename ClauseTypeAttr,
typename ClauseType>
750 ClauseTypeAttr prescriptiveness,
Value operand,
752 StringRef (*stringifyClauseType)(ClauseType)) {
754 if (prescriptiveness)
755 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
758 p << operand <<
": " << operandType;
768 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
769 Type &grainsizeType) {
771 parser, grainsizeMod, grainsize, grainsizeType,
772 &symbolizeClauseGrainsizeType,
"grainsize");
776 ClauseGrainsizeTypeAttr grainsizeMod,
779 p, op, grainsizeMod, grainsize, grainsizeType,
780 &stringifyClauseGrainsizeType);
790 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
791 Type &numTasksType) {
793 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
798 ClauseNumTasksTypeAttr numTasksMod,
801 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
810 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
811 SmallVectorImpl<Type> &types;
812 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
813 SmallVectorImpl<Type> &types)
814 : vars(vars), types(types) {}
816struct PrivateParseArgs {
817 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
818 llvm::SmallVectorImpl<Type> &types;
820 UnitAttr &needsBarrier;
822 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
823 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
824 UnitAttr &needsBarrier,
826 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
827 mapIndices(mapIndices) {}
830struct ReductionParseArgs {
831 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
832 SmallVectorImpl<Type> &types;
835 ReductionModifierAttr *modifier;
836 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
838 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
839 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
842struct AllRegionParseArgs {
843 std::optional<MapParseArgs> hasDeviceAddrArgs;
844 std::optional<MapParseArgs> hostEvalArgs;
845 std::optional<ReductionParseArgs> inReductionArgs;
846 std::optional<MapParseArgs> mapArgs;
847 std::optional<PrivateParseArgs> privateArgs;
848 std::optional<ReductionParseArgs> reductionArgs;
849 std::optional<ReductionParseArgs> taskReductionArgs;
850 std::optional<MapParseArgs> useDeviceAddrArgs;
851 std::optional<MapParseArgs> useDevicePtrArgs;
856 return "private_barrier";
866 ReductionModifierAttr *modifier =
nullptr,
867 UnitAttr *needsBarrier =
nullptr) {
871 unsigned regionArgOffset = regionPrivateArgs.size();
881 std::optional<ReductionModifier> enumValue =
882 symbolizeReductionModifier(enumStr);
883 if (!enumValue.has_value())
885 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
892 isByRefVec.push_back(
893 parser.parseOptionalKeyword(
"byref").succeeded());
895 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
898 if (parser.parseOperand(operands.emplace_back()) ||
899 parser.parseArrow() ||
900 parser.parseArgument(regionPrivateArgs.emplace_back()))
904 if (parser.parseOptionalLSquare().succeeded()) {
905 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
906 parser.parseInteger(mapIndicesVec.emplace_back()) ||
907 parser.parseRSquare())
910 mapIndicesVec.push_back(-1);
922 if (parser.parseType(types.emplace_back()))
929 if (operands.size() != types.size())
938 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
941 auto *argsBegin = regionPrivateArgs.begin();
943 argsBegin + regionArgOffset + types.size());
944 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
950 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
953 if (!mapIndicesVec.empty())
966 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
981 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
987 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
988 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
989 nullptr, &privateArgs->needsBarrier)))
998 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1003 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1004 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1005 reductionArgs->modifier)))
1012 AllRegionParseArgs args) {
1016 args.hasDeviceAddrArgs)))
1018 <<
"invalid `has_device_addr` format";
1021 args.hostEvalArgs)))
1023 <<
"invalid `host_eval` format";
1026 args.inReductionArgs)))
1028 <<
"invalid `in_reduction` format";
1033 <<
"invalid `map_entries` format";
1038 <<
"invalid `private` format";
1041 args.reductionArgs)))
1043 <<
"invalid `reduction` format";
1046 args.taskReductionArgs)))
1048 <<
"invalid `task_reduction` format";
1051 args.useDeviceAddrArgs)))
1053 <<
"invalid `use_device_addr` format";
1056 args.useDevicePtrArgs)))
1058 <<
"invalid `use_device_addr` format";
1060 return parser.
parseRegion(region, entryBlockArgs);
1079 AllRegionParseArgs args;
1080 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1081 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1082 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1083 inReductionByref, inReductionSyms);
1084 args.mapArgs.emplace(mapVars, mapTypes);
1085 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1086 privateNeedsBarrier, &privateMaps);
1097 UnitAttr &privateNeedsBarrier) {
1098 AllRegionParseArgs args;
1099 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1100 inReductionByref, inReductionSyms);
1101 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1102 privateNeedsBarrier);
1113 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1117 AllRegionParseArgs args;
1118 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1119 inReductionByref, inReductionSyms);
1120 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1121 privateNeedsBarrier);
1122 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1123 reductionSyms, &reductionMod);
1131 UnitAttr &privateNeedsBarrier) {
1132 AllRegionParseArgs args;
1133 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1134 privateNeedsBarrier);
1142 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1146 AllRegionParseArgs args;
1147 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1148 privateNeedsBarrier);
1149 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1150 reductionSyms, &reductionMod);
1159 AllRegionParseArgs args;
1160 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1161 taskReductionByref, taskReductionSyms);
1171 AllRegionParseArgs args;
1172 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1173 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1182struct MapPrintArgs {
1187struct PrivatePrintArgs {
1191 UnitAttr needsBarrier;
1195 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1196 mapIndices(mapIndices) {}
1198struct ReductionPrintArgs {
1203 ReductionModifierAttr modifier;
1205 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1206 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1208struct AllRegionPrintArgs {
1209 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1210 std::optional<MapPrintArgs> hostEvalArgs;
1211 std::optional<ReductionPrintArgs> inReductionArgs;
1212 std::optional<MapPrintArgs> mapArgs;
1213 std::optional<PrivatePrintArgs> privateArgs;
1214 std::optional<ReductionPrintArgs> reductionArgs;
1215 std::optional<ReductionPrintArgs> taskReductionArgs;
1216 std::optional<MapPrintArgs> useDeviceAddrArgs;
1217 std::optional<MapPrintArgs> useDevicePtrArgs;
1226 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1227 if (argsSubrange.empty())
1230 p << clauseName <<
"(";
1233 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1237 symbols = ArrayAttr::get(ctx, values);
1250 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1251 mapIndices.asArrayRef(),
1252 byref.asArrayRef()),
1254 auto [op, arg, sym, map, isByRef] = t;
1260 p << op <<
" -> " << arg;
1263 p <<
" [map_idx=" << map <<
"]";
1266 llvm::interleaveComma(types, p);
1274 StringRef clauseName,
ValueRange argsSubrange,
1275 std::optional<MapPrintArgs> mapArgs) {
1282 StringRef clauseName,
ValueRange argsSubrange,
1283 std::optional<PrivatePrintArgs> privateArgs) {
1286 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1287 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1288 nullptr, privateArgs->needsBarrier);
1294 std::optional<ReductionPrintArgs> reductionArgs) {
1297 reductionArgs->vars, reductionArgs->types,
1298 reductionArgs->syms,
nullptr,
1299 reductionArgs->byref, reductionArgs->modifier);
1303 const AllRegionPrintArgs &args) {
1304 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1308 iface.getHasDeviceAddrBlockArgs(),
1309 args.hasDeviceAddrArgs);
1313 args.inReductionArgs);
1319 args.reductionArgs);
1321 iface.getTaskReductionBlockArgs(),
1322 args.taskReductionArgs);
1324 iface.getUseDeviceAddrBlockArgs(),
1325 args.useDeviceAddrArgs);
1327 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1343 AllRegionPrintArgs args;
1344 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1345 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1346 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1347 inReductionByref, inReductionSyms);
1348 args.mapArgs.emplace(mapVars, mapTypes);
1349 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1350 privateNeedsBarrier, privateMaps);
1358 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1359 AllRegionPrintArgs args;
1360 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1361 inReductionByref, inReductionSyms);
1362 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1363 privateNeedsBarrier,
1372 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1373 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1376 AllRegionPrintArgs args;
1377 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1378 inReductionByref, inReductionSyms);
1379 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1380 privateNeedsBarrier,
1382 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1383 reductionSyms, reductionMod);
1390 UnitAttr privateNeedsBarrier) {
1391 AllRegionPrintArgs args;
1392 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1393 privateNeedsBarrier,
1401 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1404 AllRegionPrintArgs args;
1405 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1406 privateNeedsBarrier,
1408 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1409 reductionSyms, reductionMod);
1419 AllRegionPrintArgs args;
1420 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1421 taskReductionByref, taskReductionSyms);
1431 AllRegionPrintArgs args;
1432 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1433 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1437template <
typename ParsePrefixFn>
1446 if (failed(parsePrefix()))
1454 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1455 iteratedVars.push_back(v);
1456 iteratedTypes.push_back(ty);
1458 plainVars.push_back(v);
1459 plainTypes.push_back(ty);
1465template <
typename Pr
intPrefixFn>
1469 PrintPrefixFn &&printPrefixForPlain,
1470 PrintPrefixFn &&printPrefixForIterated) {
1477 p << v <<
" : " << t;
1481 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1482 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1483 for (
unsigned i = 0; i < plainVars.size(); ++i)
1484 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1492 if (!reductionVars.empty()) {
1493 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1495 <<
"expected as many reduction symbol references "
1496 "as reduction variables";
1497 if (reductionByref && reductionByref->size() != reductionVars.size())
1498 return op->
emitError() <<
"expected as many reduction variable by "
1499 "reference attributes as reduction variables";
1502 return op->
emitOpError() <<
"unexpected reduction symbol references";
1509 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1510 Value accum = std::get<0>(args);
1512 if (!accumulators.insert(accum).second)
1513 return op->
emitOpError() <<
"accumulator variable used more than once";
1516 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1520 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1521 <<
" to point to a reduction declaration";
1523 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1525 <<
"expected accumulator (" << varType
1526 <<
") to be the same type as reduction declaration ("
1527 << decl.getAccumulatorType() <<
")";
1546 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1547 parser.parseArrow() ||
1548 parser.parseAttribute(symsVec.emplace_back()) ||
1549 parser.parseColonType(copyprivateTypes.emplace_back()))
1555 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1563 std::optional<ArrayAttr> copyprivateSyms) {
1564 if (!copyprivateSyms.has_value())
1566 llvm::interleaveComma(
1567 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1568 [&](
const auto &args) {
1569 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1570 << std::get<2>(args);
1577 std::optional<ArrayAttr> copyprivateSyms) {
1578 size_t copyprivateSymsSize =
1579 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1580 if (copyprivateSymsSize != copyprivateVars.size())
1581 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1582 << copyprivateVars.size()
1583 <<
") and functions (= " << copyprivateSymsSize
1584 <<
"), both must be equal";
1585 if (!copyprivateSyms.has_value())
1588 for (
auto copyprivateVarAndSym :
1589 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1591 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1592 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1594 if (mlir::func::FuncOp mlirFuncOp =
1597 funcOp = mlirFuncOp;
1598 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1601 funcOp = llvmFuncOp;
1603 auto getNumArguments = [&] {
1604 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1607 auto getArgumentType = [&](
unsigned i) {
1608 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1613 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1614 <<
" to point to a copy function";
1616 if (getNumArguments() != 2)
1618 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1620 Type argTy = getArgumentType(0);
1621 if (argTy != getArgumentType(1))
1622 return op->
emitOpError() <<
"expected copy function " << symbolRef
1623 <<
" arguments to have the same type";
1625 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1626 if (argTy != varType)
1628 <<
"expected copy function arguments' type (" << argTy
1629 <<
") to be the same as copyprivate variable's type (" << varType
1650 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1651 parser.parseOperand(dependVars.emplace_back()) ||
1652 parser.parseColonType(dependTypes.emplace_back()))
1654 if (std::optional<ClauseTaskDepend> keywordDepend =
1655 (symbolizeClauseTaskDepend(keyword)))
1656 kindsVec.emplace_back(
1657 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1664 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1671 std::optional<ArrayAttr> dependKinds) {
1673 for (
unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1676 p << stringifyClauseTaskDepend(
1677 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1679 <<
" -> " << dependVars[i] <<
" : " << dependTypes[i];
1685 std::optional<ArrayAttr> dependKinds,
1687 if (!dependVars.empty()) {
1688 if (!dependKinds || dependKinds->size() != dependVars.size())
1689 return op->
emitOpError() <<
"expected as many depend values"
1690 " as depend variables";
1692 if (dependKinds && !dependKinds->empty())
1693 return op->
emitOpError() <<
"unexpected depend values";
1709 IntegerAttr &hintAttr) {
1710 StringRef hintKeyword;
1716 auto parseKeyword = [&]() -> ParseResult {
1719 if (hintKeyword ==
"uncontended")
1721 else if (hintKeyword ==
"contended")
1723 else if (hintKeyword ==
"nonspeculative")
1725 else if (hintKeyword ==
"speculative")
1729 << hintKeyword <<
" is not a valid hint";
1740 IntegerAttr hintAttr) {
1741 int64_t hint = hintAttr.getInt();
1749 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1751 bool uncontended = bitn(hint, 0);
1752 bool contended = bitn(hint, 1);
1753 bool nonspeculative = bitn(hint, 2);
1754 bool speculative = bitn(hint, 3);
1758 hints.push_back(
"uncontended");
1760 hints.push_back(
"contended");
1762 hints.push_back(
"nonspeculative");
1764 hints.push_back(
"speculative");
1766 llvm::interleaveComma(hints, p);
1773 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1775 bool uncontended = bitn(hint, 0);
1776 bool contended = bitn(hint, 1);
1777 bool nonspeculative = bitn(hint, 2);
1778 bool speculative = bitn(hint, 3);
1780 if (uncontended && contended)
1781 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1782 "omp_sync_hint_contended cannot be combined";
1783 if (nonspeculative && speculative)
1784 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1785 "omp_sync_hint_speculative cannot be combined.";
1796 return (value & flag) == flag;
1804static ParseResult parseMapClause(
OpAsmParser &parser,
1805 ClauseMapFlagsAttr &mapType) {
1806 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1809 auto parseTypeAndMod = [&]() -> ParseResult {
1810 StringRef mapTypeMod;
1814 if (mapTypeMod ==
"always")
1815 mapTypeBits |= ClauseMapFlags::always;
1817 if (mapTypeMod ==
"implicit")
1818 mapTypeBits |= ClauseMapFlags::implicit;
1820 if (mapTypeMod ==
"ompx_hold")
1821 mapTypeBits |= ClauseMapFlags::ompx_hold;
1823 if (mapTypeMod ==
"close")
1824 mapTypeBits |= ClauseMapFlags::close;
1826 if (mapTypeMod ==
"present")
1827 mapTypeBits |= ClauseMapFlags::present;
1829 if (mapTypeMod ==
"to")
1830 mapTypeBits |= ClauseMapFlags::to;
1832 if (mapTypeMod ==
"from")
1833 mapTypeBits |= ClauseMapFlags::from;
1835 if (mapTypeMod ==
"tofrom")
1836 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1838 if (mapTypeMod ==
"delete")
1839 mapTypeBits |= ClauseMapFlags::del;
1841 if (mapTypeMod ==
"storage")
1842 mapTypeBits |= ClauseMapFlags::storage;
1844 if (mapTypeMod ==
"return_param")
1845 mapTypeBits |= ClauseMapFlags::return_param;
1847 if (mapTypeMod ==
"private")
1848 mapTypeBits |= ClauseMapFlags::priv;
1850 if (mapTypeMod ==
"literal")
1851 mapTypeBits |= ClauseMapFlags::literal;
1853 if (mapTypeMod ==
"attach")
1854 mapTypeBits |= ClauseMapFlags::attach;
1856 if (mapTypeMod ==
"attach_always")
1857 mapTypeBits |= ClauseMapFlags::attach_always;
1859 if (mapTypeMod ==
"attach_never")
1860 mapTypeBits |= ClauseMapFlags::attach_never;
1862 if (mapTypeMod ==
"attach_auto")
1863 mapTypeBits |= ClauseMapFlags::attach_auto;
1865 if (mapTypeMod ==
"ref_ptr")
1866 mapTypeBits |= ClauseMapFlags::ref_ptr;
1868 if (mapTypeMod ==
"ref_ptee")
1869 mapTypeBits |= ClauseMapFlags::ref_ptee;
1871 if (mapTypeMod ==
"ref_ptr_ptee")
1872 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1874 if (mapTypeMod ==
"is_device_ptr")
1875 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1892 ClauseMapFlagsAttr mapType) {
1894 ClauseMapFlags mapFlags = mapType.getValue();
1899 mapTypeStrs.push_back(
"always");
1901 mapTypeStrs.push_back(
"implicit");
1903 mapTypeStrs.push_back(
"ompx_hold");
1905 mapTypeStrs.push_back(
"close");
1907 mapTypeStrs.push_back(
"present");
1916 mapTypeStrs.push_back(
"tofrom");
1918 mapTypeStrs.push_back(
"from");
1920 mapTypeStrs.push_back(
"to");
1923 mapTypeStrs.push_back(
"delete");
1925 mapTypeStrs.push_back(
"return_param");
1927 mapTypeStrs.push_back(
"storage");
1929 mapTypeStrs.push_back(
"private");
1931 mapTypeStrs.push_back(
"literal");
1933 mapTypeStrs.push_back(
"attach");
1935 mapTypeStrs.push_back(
"attach_always");
1937 mapTypeStrs.push_back(
"attach_never");
1939 mapTypeStrs.push_back(
"attach_auto");
1941 mapTypeStrs.push_back(
"ref_ptr");
1943 mapTypeStrs.push_back(
"ref_ptee");
1945 mapTypeStrs.push_back(
"ref_ptr_ptee");
1947 mapTypeStrs.push_back(
"is_device_ptr");
1948 if (mapFlags == ClauseMapFlags::none)
1949 mapTypeStrs.push_back(
"none");
1951 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1952 p << mapTypeStrs[i];
1953 if (i + 1 < mapTypeStrs.size()) {
1959static ParseResult parseMembersIndex(
OpAsmParser &parser,
1963 auto parseIndices = [&]() -> ParseResult {
1968 APInt(64, value,
false)));
1982 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
1986 if (!memberIdxs.empty())
1987 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
1997 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
1999 auto memberIdx = cast<ArrayAttr>(v);
2000 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2001 p << cast<IntegerAttr>(v2).getInt();
2008 VariableCaptureKindAttr mapCaptureType) {
2009 std::string typeCapStr;
2010 llvm::raw_string_ostream typeCap(typeCapStr);
2011 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2013 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2014 typeCap <<
"ByCopy";
2015 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2016 typeCap <<
"VLAType";
2017 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2023 VariableCaptureKindAttr &mapCaptureType) {
2024 StringRef mapCaptureKey;
2028 if (mapCaptureKey ==
"This")
2029 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2030 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2031 if (mapCaptureKey ==
"ByRef")
2032 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2033 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2034 if (mapCaptureKey ==
"ByCopy")
2035 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2036 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2037 if (mapCaptureKey ==
"VLAType")
2038 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2039 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2048 for (
auto mapOp : mapVars) {
2049 if (!mapOp.getDefiningOp())
2052 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2053 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2056 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2059 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2060 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2061 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2063 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2065 "to, from, tofrom and alloc map types are permitted");
2067 if (isa<TargetEnterDataOp>(op) && (from || del))
2068 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2070 if (isa<TargetExitDataOp>(op) && to)
2072 "from, release and delete map types are permitted");
2074 if (isa<TargetUpdateOp>(op)) {
2077 "at least one of to or from map types must be "
2078 "specified, other map types are not permitted");
2083 "at least one of to or from map types must be "
2084 "specified, other map types are not permitted");
2087 auto updateVar = mapInfoOp.getVarPtr();
2089 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2090 (from && updateToVars.contains(updateVar))) {
2093 "either to or from map types can be specified, not both");
2096 if (always || close || implicit) {
2099 "present, mapper and iterator map type modifiers are permitted");
2102 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2104 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2106 "map argument is not a map entry operation");
2114 std::optional<DenseI64ArrayAttr> privateMapIndices =
2115 targetOp.getPrivateMapsAttr();
2118 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2123 if (privateMapIndices.value().size() !=
2124 static_cast<int64_t>(privateVars.size()))
2125 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2126 "`private_maps` attribute mismatch");
2136 StringRef clauseName,
2138 for (
Value var : vars)
2139 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2141 <<
"'" << clauseName
2142 <<
"' arguments must be defined by 'omp.map.info' ops";
2146LogicalResult MapInfoOp::verify() {
2147 if (getMapperId() &&
2149 *
this, getMapperIdAttr())) {
2164 const TargetDataOperands &clauses) {
2165 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2166 clauses.mapVars, clauses.useDeviceAddrVars,
2167 clauses.useDevicePtrVars);
2170LogicalResult TargetDataOp::verify() {
2171 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2172 getUseDeviceAddrVars().empty()) {
2173 return ::emitError(this->getLoc(),
2174 "At least one of map, use_device_ptr_vars, or "
2175 "use_device_addr_vars operand must be present");
2179 getUseDevicePtrVars())))
2183 getUseDeviceAddrVars())))
2193void TargetEnterDataOp::build(
2197 TargetEnterDataOp::build(builder, state,
2199 clauses.dependVars, clauses.device, clauses.ifExpr,
2200 clauses.mapVars, clauses.nowait);
2203LogicalResult TargetEnterDataOp::verify() {
2204 LogicalResult verifyDependVars =
2206 return failed(verifyDependVars) ? verifyDependVars
2217 TargetExitDataOp::build(builder, state,
2219 clauses.dependVars, clauses.device, clauses.ifExpr,
2220 clauses.mapVars, clauses.nowait);
2223LogicalResult TargetExitDataOp::verify() {
2224 LogicalResult verifyDependVars =
2226 return failed(verifyDependVars) ? verifyDependVars
2237 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2238 clauses.dependVars, clauses.device, clauses.ifExpr,
2239 clauses.mapVars, clauses.nowait);
2242LogicalResult TargetUpdateOp::verify() {
2243 LogicalResult verifyDependVars =
2245 return failed(verifyDependVars) ? verifyDependVars
2254 const TargetOperands &clauses) {
2258 TargetOp::build(builder, state, {}, {},
2260 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2261 clauses.hostEvalVars, clauses.ifExpr,
2263 nullptr, clauses.isDevicePtrVars,
2264 clauses.mapVars, clauses.nowait, clauses.privateVars,
2266 clauses.privateNeedsBarrier, clauses.threadLimitVars,
2270LogicalResult TargetOp::verify() {
2275 getHasDeviceAddrVars())))
2284LogicalResult TargetOp::verifyRegions() {
2285 auto teamsOps = getOps<TeamsOp>();
2286 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2287 return emitError(
"target containing multiple 'omp.teams' nested ops");
2290 Operation *capturedOp = getInnermostCapturedOmpOp();
2291 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2292 for (
Value hostEvalArg :
2293 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2295 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2297 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2298 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2299 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2302 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2303 "and 'thread_limit' in 'omp.teams'";
2305 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2306 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2307 parallelOp->isAncestor(capturedOp) &&
2308 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2312 <<
"host_eval argument only legal as 'num_threads' in "
2313 "'omp.parallel' when representing target SPMD";
2315 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2316 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2317 loopNestOp.getOperation() == capturedOp &&
2318 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2319 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2320 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2323 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2324 "and steps in 'omp.loop_nest' when trip count "
2325 "must be evaluated in the host";
2328 return emitOpError() <<
"host_eval argument illegal use in '"
2329 << user->getName() <<
"' operation";
2338 assert(rootOp &&
"expected valid operation");
2355 bool isOmpDialect = op->
getDialect() == ompDialect;
2357 if (!isOmpDialect || !hasRegions)
2364 if (checkSingleMandatoryExec) {
2369 if (successor->isReachable(parentBlock))
2372 for (
Block &block : *parentRegion)
2374 !domInfo.
dominates(parentBlock, &block))
2381 if (&sibling != op && !siblingAllowedFn(&sibling))
2394Operation *TargetOp::getInnermostCapturedOmpOp() {
2395 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2407 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2410 memOp.getEffects(effects);
2411 return !llvm::any_of(
2413 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2414 isa<SideEffects::AutomaticAllocationScopeResource>(
2424 WsloopOp *wsLoopOp) {
2426 if (!teamsOp.getNumTeamsUpperVars().empty())
2430 if (teamsOp.getNumReductionVars())
2432 if (wsLoopOp->getNumReductionVars())
2436 OffloadModuleInterface offloadMod =
2440 auto ompFlags = offloadMod.getFlags();
2443 return ompFlags.getAssumeTeamsOversubscription() &&
2444 ompFlags.getAssumeThreadsOversubscription();
2447TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2452 assert((!capturedOp ||
2453 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2454 "unexpected captured op");
2457 if (!isa_and_present<LoopNestOp>(capturedOp))
2458 return TargetRegionFlags::generic;
2462 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2463 assert(!loopWrappers.empty());
2465 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2466 if (isa<SimdOp>(innermostWrapper))
2467 innermostWrapper = std::next(innermostWrapper);
2469 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2470 if (numWrappers != 1 && numWrappers != 2)
2471 return TargetRegionFlags::generic;
2474 if (numWrappers == 2) {
2475 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2477 return TargetRegionFlags::generic;
2479 innermostWrapper = std::next(innermostWrapper);
2480 if (!isa<DistributeOp>(innermostWrapper))
2481 return TargetRegionFlags::generic;
2484 if (!isa_and_present<ParallelOp>(parallelOp))
2485 return TargetRegionFlags::generic;
2487 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2489 return TargetRegionFlags::generic;
2491 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2492 TargetRegionFlags
result =
2493 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2500 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2502 if (!isa_and_present<TeamsOp>(teamsOp))
2503 return TargetRegionFlags::generic;
2505 if (teamsOp->
getParentOp() != targetOp.getOperation())
2506 return TargetRegionFlags::generic;
2508 if (isa<LoopOp>(innermostWrapper))
2509 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2519 Dialect *ompDialect = targetOp->getDialect();
2523 return sibling && (ompDialect != sibling->
getDialect() ||
2527 TargetRegionFlags
result =
2528 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2533 while (nestedCapture->
getParentOp() != capturedOp)
2536 return isa<ParallelOp>(nestedCapture) ?
result | TargetRegionFlags::spmd
2540 else if (isa<WsloopOp>(innermostWrapper)) {
2542 if (!isa_and_present<ParallelOp>(parallelOp))
2543 return TargetRegionFlags::generic;
2545 if (parallelOp->
getParentOp() == targetOp.getOperation())
2546 return TargetRegionFlags::spmd;
2549 return TargetRegionFlags::generic;
2558 ParallelOp::build(builder, state,
ValueRange(),
2570 const ParallelOperands &clauses) {
2572 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2573 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2575 clauses.privateNeedsBarrier, clauses.procBindKind,
2576 clauses.reductionMod, clauses.reductionVars,
2581template <
typename OpType>
2583 auto privateVars = op.getPrivateVars();
2584 auto privateSyms = op.getPrivateSymsAttr();
2586 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2589 auto numPrivateVars = privateVars.size();
2590 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2592 if (numPrivateVars != numPrivateSyms)
2593 return op.emitError() <<
"inconsistent number of private variables and "
2594 "privatizer op symbols, private vars: "
2596 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2598 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2599 Type varType = std::get<0>(privateVarInfo).getType();
2600 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2601 PrivateClauseOp privatizerOp =
2604 if (privatizerOp ==
nullptr)
2605 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2606 << privateSym <<
"'";
2608 Type privatizerType = privatizerOp.getArgType();
2610 if (privatizerType && (varType != privatizerType))
2611 return op.emitError()
2612 <<
"type mismatch between a "
2613 << (privatizerOp.getDataSharingType() ==
2614 DataSharingClauseType::Private
2617 <<
" variable and its privatizer op, var type: " << varType
2618 <<
" vs. privatizer op type: " << privatizerType;
2624LogicalResult ParallelOp::verify() {
2625 if (getAllocateVars().size() != getAllocatorVars().size())
2627 "expected equal sizes for allocate and allocator variables");
2633 getReductionByref());
2636LogicalResult ParallelOp::verifyRegions() {
2637 auto distChildOps = getOps<DistributeOp>();
2638 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2639 if (numDistChildOps > 1)
2641 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2643 if (numDistChildOps == 1) {
2646 <<
"'omp.composite' attribute missing from composite operation";
2648 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2649 Operation &distributeOp = **distChildOps.begin();
2651 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2655 return emitError() <<
"unexpected OpenMP operation inside of composite "
2657 << childOp.getName();
2659 }
else if (isComposite()) {
2661 <<
"'omp.composite' attribute present in non-composite operation";
2678 const TeamsOperands &clauses) {
2682 builder, state, clauses.allocateVars, clauses.allocatorVars,
2683 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2685 nullptr, clauses.reductionMod,
2686 clauses.reductionVars,
2688 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2695 if (numTeamsLower) {
2696 if (numTeamsUpperVars.size() != 1)
2698 "expected exactly one num_teams upper bound when lower bound is "
2702 "expected num_teams upper bound and lower bound to be "
2709LogicalResult TeamsOp::verify() {
2718 return emitError(
"expected to be nested inside of omp.target or not nested "
2719 "in any OpenMP dialect operations");
2723 this->getNumTeamsUpperVars())))
2727 if (getAllocateVars().size() != getAllocatorVars().size())
2729 "expected equal sizes for allocate and allocator variables");
2732 getReductionByref());
2740 return getParentOp().getPrivateVars();
2744 return getParentOp().getReductionVars();
2752 const SectionsOperands &clauses) {
2755 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2758 clauses.reductionMod, clauses.reductionVars,
2763LogicalResult SectionsOp::verify() {
2764 if (getAllocateVars().size() != getAllocatorVars().size())
2766 "expected equal sizes for allocate and allocator variables");
2769 getReductionByref());
2772LogicalResult SectionsOp::verifyRegions() {
2773 for (
auto &inst : *getRegion().begin()) {
2774 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2776 <<
"expected omp.section op or terminator op inside region";
2788 const SingleOperands &clauses) {
2791 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2792 clauses.copyprivateVars,
2793 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2798LogicalResult SingleOp::verify() {
2800 if (getAllocateVars().size() != getAllocatorVars().size())
2802 "expected equal sizes for allocate and allocator variables");
2805 getCopyprivateSyms());
2813 const WorkshareOperands &clauses) {
2814 WorkshareOp::build(builder, state, clauses.nowait);
2821LogicalResult WorkshareLoopWrapperOp::verify() {
2822 if (!(*this)->getParentOfType<WorkshareOp>())
2823 return emitOpError() <<
"must be nested in an omp.workshare";
2827LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2828 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2830 return emitOpError() <<
"expected to be a standalone loop wrapper";
2839LogicalResult LoopWrapperInterface::verifyImpl() {
2843 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2844 "and `SingleBlock` traits";
2847 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2850 if (range_size(region.
getOps()) != 1)
2852 <<
"loop wrapper does not contain exactly one nested op";
2855 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2856 return emitOpError() <<
"nested in loop wrapper is not another loop "
2857 "wrapper or `omp.loop_nest`";
2867 const LoopOperands &clauses) {
2870 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2872 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2873 clauses.reductionMod, clauses.reductionVars,
2878LogicalResult LoopOp::verify() {
2880 getReductionByref());
2883LogicalResult LoopOp::verifyRegions() {
2884 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2886 return emitOpError() <<
"expected to be a standalone loop wrapper";
2897 build(builder, state, {}, {},
2900 false,
nullptr,
nullptr,
2901 nullptr, {},
nullptr,
2912 const WsloopOperands &clauses) {
2917 {}, {}, clauses.linearVars,
2918 clauses.linearStepVars, clauses.linearVarTypes, clauses.nowait,
2919 clauses.order, clauses.orderMod, clauses.ordered, clauses.privateVars,
2920 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2921 clauses.reductionMod, clauses.reductionVars,
2923 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2924 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2927LogicalResult WsloopOp::verify() {
2928 if (getLinearVars().size() &&
2929 getLinearVarTypes().value().size() != getLinearVars().size())
2930 return emitError() <<
"Ill-formed type attributes for linear variables";
2932 getReductionByref());
2935LogicalResult WsloopOp::verifyRegions() {
2936 bool isCompositeChildLeaf =
2937 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2939 if (LoopWrapperInterface nested = getNestedWrapper()) {
2942 <<
"'omp.composite' attribute missing from composite wrapper";
2946 if (!isa<SimdOp>(nested))
2947 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
2949 }
else if (isComposite() && !isCompositeChildLeaf) {
2951 <<
"'omp.composite' attribute present in non-composite wrapper";
2952 }
else if (!isComposite() && isCompositeChildLeaf) {
2954 <<
"'omp.composite' attribute missing from composite wrapper";
2965 const SimdOperands &clauses) {
2968 builder, state, clauses.alignedVars,
2970 clauses.linearVars, clauses.linearStepVars, clauses.linearVarTypes,
2971 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2972 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
2973 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
2979LogicalResult SimdOp::verify() {
2980 if (getSimdlen().has_value() && getSafelen().has_value() &&
2981 getSimdlen().value() > getSafelen().value())
2983 <<
"simdlen clause and safelen clause are both present, but the "
2984 "simdlen value is not less than or equal to safelen value";
2992 bool isCompositeChildLeaf =
2993 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2995 if (!isComposite() && isCompositeChildLeaf)
2997 <<
"'omp.composite' attribute missing from composite wrapper";
2999 if (isComposite() && !isCompositeChildLeaf)
3001 <<
"'omp.composite' attribute present in non-composite wrapper";
3005 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3007 for (
const Attribute &sym : *privateSyms) {
3008 auto symRef = cast<SymbolRefAttr>(sym);
3009 omp::PrivateClauseOp privatizer =
3011 getOperation(), symRef);
3013 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3014 if (privatizer.getDataSharingType() ==
3015 DataSharingClauseType::FirstPrivate)
3016 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3020 if (getLinearVars().size() &&
3021 getLinearVarTypes().value().size() != getLinearVars().size())
3022 return emitError() <<
"Ill-formed type attributes for linear variables";
3026LogicalResult SimdOp::verifyRegions() {
3027 if (getNestedWrapper())
3028 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3038 const DistributeOperands &clauses) {
3039 DistributeOp::build(builder, state, clauses.allocateVars,
3040 clauses.allocatorVars, clauses.distScheduleStatic,
3041 clauses.distScheduleChunkSize, clauses.order,
3042 clauses.orderMod, clauses.privateVars,
3044 clauses.privateNeedsBarrier);
3047LogicalResult DistributeOp::verify() {
3048 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3050 "dist_schedule_static being present";
3052 if (getAllocateVars().size() != getAllocatorVars().size())
3054 "expected equal sizes for allocate and allocator variables");
3059LogicalResult DistributeOp::verifyRegions() {
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3063 <<
"'omp.composite' attribute missing from composite wrapper";
3066 if (isa<WsloopOp>(nested)) {
3068 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3069 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3070 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3071 "when a composite 'omp.parallel' is the direct "
3074 }
else if (!isa<SimdOp>(nested))
3075 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3077 }
else if (isComposite()) {
3079 <<
"'omp.composite' attribute present in non-composite wrapper";
3089LogicalResult DeclareMapperInfoOp::verify() {
3093LogicalResult DeclareMapperOp::verifyRegions() {
3094 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3095 getRegion().getBlocks().front().getTerminator()))
3096 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3105LogicalResult DeclareReductionOp::verifyRegions() {
3106 if (!getAllocRegion().empty()) {
3107 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3108 if (yieldOp.getResults().size() != 1 ||
3109 yieldOp.getResults().getTypes()[0] !=
getType())
3110 return emitOpError() <<
"expects alloc region to yield a value "
3111 "of the reduction type";
3115 if (getInitializerRegion().empty())
3116 return emitOpError() <<
"expects non-empty initializer region";
3117 Block &initializerEntryBlock = getInitializerRegion().
front();
3120 if (!getAllocRegion().empty())
3121 return emitOpError() <<
"expects two arguments to the initializer region "
3122 "when an allocation region is used";
3124 if (getAllocRegion().empty())
3125 return emitOpError() <<
"expects one argument to the initializer region "
3126 "when no allocation region is used";
3129 <<
"expects one or two arguments to the initializer region";
3133 if (arg.getType() !=
getType())
3134 return emitOpError() <<
"expects initializer region argument to match "
3135 "the reduction type";
3137 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3138 if (yieldOp.getResults().size() != 1 ||
3139 yieldOp.getResults().getTypes()[0] !=
getType())
3140 return emitOpError() <<
"expects initializer region to yield a value "
3141 "of the reduction type";
3144 if (getReductionRegion().empty())
3145 return emitOpError() <<
"expects non-empty reduction region";
3146 Block &reductionEntryBlock = getReductionRegion().
front();
3151 return emitOpError() <<
"expects reduction region with two arguments of "
3152 "the reduction type";
3153 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3154 if (yieldOp.getResults().size() != 1 ||
3155 yieldOp.getResults().getTypes()[0] !=
getType())
3156 return emitOpError() <<
"expects reduction region to yield a value "
3157 "of the reduction type";
3160 if (!getAtomicReductionRegion().empty()) {
3161 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3165 return emitOpError() <<
"expects atomic reduction region with two "
3166 "arguments of the same type";
3167 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3170 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3171 return emitOpError() <<
"expects atomic reduction region arguments to "
3172 "be accumulators containing the reduction type";
3175 if (getCleanupRegion().empty())
3177 Block &cleanupEntryBlock = getCleanupRegion().
front();
3180 return emitOpError() <<
"expects cleanup region with one argument "
3181 "of the reduction type";
3191 const TaskOperands &clauses) {
3193 TaskOp::build(builder, state, clauses.iterated, clauses.affinityVars,
3194 clauses.allocateVars, clauses.allocatorVars,
3195 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3196 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3198 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3199 clauses.priority, clauses.privateVars,
3201 clauses.privateNeedsBarrier, clauses.untied,
3202 clauses.eventHandle);
3205LogicalResult TaskOp::verify() {
3206 LogicalResult verifyDependVars =
3208 return failed(verifyDependVars)
3211 getInReductionVars(),
3212 getInReductionByref());
3220 const TaskgroupOperands &clauses) {
3222 TaskgroupOp::build(builder, state, clauses.allocateVars,
3223 clauses.allocatorVars, clauses.taskReductionVars,
3228LogicalResult TaskgroupOp::verify() {
3230 getTaskReductionVars(),
3231 getTaskReductionByref());
3239 const TaskloopOperands &clauses) {
3242 builder, state, clauses.allocateVars, clauses.allocatorVars,
3243 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3244 clauses.inReductionVars,
3246 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3247 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3248 clauses.privateVars,
3250 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3255LogicalResult TaskloopOp::verify() {
3256 if (getAllocateVars().size() != getAllocatorVars().size())
3258 "expected equal sizes for allocate and allocator variables");
3260 getReductionVars(), getReductionByref())) ||
3262 getInReductionVars(),
3263 getInReductionByref())))
3266 if (!getReductionVars().empty() && getNogroup())
3267 return emitError(
"if a reduction clause is present on the taskloop "
3268 "directive, the nogroup clause must not be specified");
3269 for (
auto var : getReductionVars()) {
3270 if (llvm::is_contained(getInReductionVars(), var))
3271 return emitError(
"the same list item cannot appear in both a reduction "
3272 "and an in_reduction clause");
3275 if (getGrainsize() && getNumTasks()) {
3277 "the grainsize clause and num_tasks clause are mutually exclusive and "
3278 "may not appear on the same taskloop directive");
3284LogicalResult TaskloopOp::verifyRegions() {
3285 if (LoopWrapperInterface nested = getNestedWrapper()) {
3288 <<
"'omp.composite' attribute missing from composite wrapper";
3292 if (!isa<SimdOp>(nested))
3293 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3294 }
else if (isComposite()) {
3296 <<
"'omp.composite' attribute present in non-composite wrapper";
3320 for (
auto &iv : ivs)
3321 iv.type = loopVarType;
3326 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3342 "collapse_num_loops",
3347 auto parseTiles = [&]() -> ParseResult {
3351 tiles.push_back(
tile);
3360 if (tiles.size() > 0)
3379 Region ®ion = getRegion();
3381 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3382 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3383 if (getLoopInclusive())
3385 p <<
"step (" << getLoopSteps() <<
") ";
3386 if (
int64_t numCollapse = getCollapseNumLoops())
3387 if (numCollapse > 1)
3388 p <<
"collapse(" << numCollapse <<
") ";
3391 p <<
"tiles(" << tiles.value() <<
") ";
3397 const LoopNestOperands &clauses) {
3399 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3400 clauses.loopLowerBounds, clauses.loopUpperBounds,
3401 clauses.loopSteps, clauses.loopInclusive,
3405LogicalResult LoopNestOp::verify() {
3406 if (getLoopLowerBounds().empty())
3407 return emitOpError() <<
"must represent at least one loop";
3409 if (getLoopLowerBounds().size() != getIVs().size())
3410 return emitOpError() <<
"number of range arguments and IVs do not match";
3412 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3413 if (lb.getType() != iv.getType())
3415 <<
"range argument type does not match corresponding IV type";
3418 uint64_t numIVs = getIVs().size();
3420 if (
const auto &numCollapse = getCollapseNumLoops())
3421 if (numCollapse > numIVs)
3423 <<
"collapse value is larger than the number of loops";
3426 if (tiles.value().size() > numIVs)
3427 return emitOpError() <<
"too few canonical loops for tile dimensions";
3429 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3430 return emitOpError() <<
"expects parent op to be a loop wrapper";
3435void LoopNestOp::gatherWrappers(
3438 while (
auto wrapper =
3439 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3440 wrappers.push_back(wrapper);
3449std::tuple<NewCliOp, OpOperand *, OpOperand *>
3455 return {{},
nullptr,
nullptr};
3458 "Unexpected type of cli");
3464 auto op = cast<LoopTransformationInterface>(use.getOwner());
3466 unsigned opnum = use.getOperandNumber();
3467 if (op.isGeneratee(opnum)) {
3468 assert(!gen &&
"Each CLI may have at most one def");
3470 }
else if (op.isApplyee(opnum)) {
3471 assert(!cons &&
"Each CLI may have at most one consumer");
3474 llvm_unreachable(
"Unexpected operand for a CLI");
3478 return {create, gen, cons};
3501 std::string cliName{
"cli"};
3505 .Case([&](CanonicalLoopOp op) {
3508 .Case([&](UnrollHeuristicOp op) -> std::string {
3509 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3511 .Case([&](FuseOp op) -> std::string {
3512 unsigned opnum =
generator->getOperandNumber();
3515 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3516 return "canonloop_fuse";
3520 .Case([&](TileOp op) -> std::string {
3521 auto [generateesFirst, generateesCount] =
3522 op.getGenerateesODSOperandIndexAndLength();
3523 unsigned firstGrid = generateesFirst;
3524 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3525 unsigned end = generateesFirst + generateesCount;
3526 unsigned opnum =
generator->getOperandNumber();
3528 if (firstGrid <= opnum && opnum < firstIntratile) {
3529 unsigned gridnum = opnum - firstGrid + 1;
3530 return (
"grid" + Twine(gridnum)).str();
3532 if (firstIntratile <= opnum && opnum < end) {
3533 unsigned intratilenum = opnum - firstIntratile + 1;
3534 return (
"intratile" + Twine(intratilenum)).str();
3536 llvm_unreachable(
"Unexpected generatee argument");
3538 .DefaultUnreachable(
"TODO: Custom name for this operation");
3541 setNameFn(
result, cliName);
3544LogicalResult NewCliOp::verify() {
3545 Value cli = getResult();
3548 "Unexpected type of cli");
3554 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3556 unsigned opnum = use.getOperandNumber();
3557 if (op.isGeneratee(opnum)) {
3560 emitOpError(
"CLI must have at most one generator");
3562 .
append(
"first generator here:");
3564 .
append(
"second generator here:");
3569 }
else if (op.isApplyee(opnum)) {
3572 emitOpError(
"CLI must have at most one consumer");
3574 .
append(
"first consumer here:")
3578 .
append(
"second consumer here:")
3585 llvm_unreachable(
"Unexpected operand for a CLI");
3593 .
append(
"see consumer here: ")
3616 setNameFn(&getRegion().front(),
"body_entry");
3619void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3627 p <<
'(' << getCli() <<
')';
3628 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3629 <<
" in range(" << getTripCount() <<
") ";
3639 CanonicalLoopInfoType cliType =
3640 CanonicalLoopInfoType::get(parser.
getContext());
3665 if (parser.
parseRegion(*region, {inductionVariable}))
3670 result.operands.append(cliOperand);
3676 return mlir::success();
3679LogicalResult CanonicalLoopOp::verify() {
3682 if (!getRegion().empty()) {
3683 Region ®ion = getRegion();
3686 "Canonical loop region must have exactly one argument");
3690 "Region argument must be the same type as the trip count");
3696Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3698std::pair<unsigned, unsigned>
3699CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3704std::pair<unsigned, unsigned>
3705CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3706 return getODSOperandIndexAndLength(odsIndex_cli);
3720 p <<
'(' << getApplyee() <<
')';
3727 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
3750 return mlir::success();
3753std::pair<unsigned, unsigned>
3754UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3755 return getODSOperandIndexAndLength(odsIndex_applyee);
3758std::pair<unsigned, unsigned>
3759UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3770 if (!generatees.empty())
3771 p <<
'(' << llvm::interleaved(generatees) <<
')';
3773 if (!applyees.empty())
3774 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3816 bool isOnlyCanonLoops =
true;
3818 for (
Value applyee : op.getApplyees()) {
3819 auto [create, gen, cons] =
decodeCli(applyee);
3822 return op.emitOpError() <<
"applyee CLI has no generator";
3824 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3825 canonLoops.push_back(loop);
3827 isOnlyCanonLoops =
false;
3832 if (!isOnlyCanonLoops)
3836 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
3837 auto parentLoop = canonLoops[i - 1];
3838 auto loop = canonLoops[i];
3840 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
3841 return op.emitOpError()
3842 <<
"tiled loop nest must be nested within each other";
3844 parentIVs.insert(parentLoop.getInductionVar());
3849 bool isPerfectlyNested = [&]() {
3850 auto &parentBody = parentLoop.getRegion();
3851 if (!parentBody.hasOneBlock())
3853 auto &parentBlock = parentBody.getBlocks().front();
3855 auto nestedLoopIt = parentBlock.begin();
3856 if (nestedLoopIt == parentBlock.end() ||
3857 (&*nestedLoopIt != loop.getOperation()))
3860 auto termIt = std::next(nestedLoopIt);
3861 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3864 if (std::next(termIt) != parentBlock.end())
3869 if (!isPerfectlyNested)
3870 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
3872 if (parentIVs.contains(loop.getTripCount()))
3873 return op.emitOpError() <<
"tiled loop nest must be rectangular";
3890LogicalResult TileOp::verify() {
3891 if (getApplyees().empty())
3892 return emitOpError() <<
"must apply to at least one loop";
3894 if (getSizes().size() != getApplyees().size())
3895 return emitOpError() <<
"there must be one tile size for each applyee";
3897 if (!getGeneratees().empty() &&
3898 2 * getSizes().size() != getGeneratees().size())
3900 <<
"expecting two times the number of generatees than applyees";
3905std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3906 return getODSOperandIndexAndLength(odsIndex_applyees);
3909std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3910 return getODSOperandIndexAndLength(odsIndex_generatees);
3920 if (!generatees.empty())
3921 p <<
'(' << llvm::interleaved(generatees) <<
')';
3923 if (!applyees.empty())
3924 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3927LogicalResult FuseOp::verify() {
3928 if (getApplyees().size() < 2)
3929 return emitOpError() <<
"must apply to at least two loops";
3931 if (getFirst().has_value() && getCount().has_value()) {
3932 int64_t first = getFirst().value();
3933 int64_t count = getCount().value();
3934 if ((
unsigned)(first + count - 1) > getApplyees().size())
3935 return emitOpError() <<
"the numbers of applyees must be at least first "
3936 "minus one plus count attributes";
3937 if (!getGeneratees().empty() &&
3938 getGeneratees().size() != getApplyees().size() + 1 - count)
3939 return emitOpError() <<
"the number of generatees must be the number of "
3940 "aplyees plus one minus count";
3943 if (!getGeneratees().empty() && getGeneratees().size() != 1)
3945 <<
"in a complete fuse the number of generatees must be exactly 1";
3947 for (
auto &&applyee : getApplyees()) {
3948 auto [create, gen, cons] =
decodeCli(applyee);
3951 return emitOpError() <<
"applyee CLI has no generator";
3952 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3955 <<
"currently only supports omp.canonical_loop as applyee";
3959std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
3960 return getODSOperandIndexAndLength(odsIndex_applyees);
3963std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
3964 return getODSOperandIndexAndLength(odsIndex_generatees);
3972 const CriticalDeclareOperands &clauses) {
3973 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3976LogicalResult CriticalDeclareOp::verify() {
3981 if (getNameAttr()) {
3982 SymbolRefAttr symbolRef = getNameAttr();
3986 return emitOpError() <<
"expected symbol reference " << symbolRef
3987 <<
" to point to a critical declaration";
4007 return op.
emitOpError() <<
"must be nested inside of a loop";
4011 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4012 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4014 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4015 "have an ordered clause";
4017 if (hasRegion && orderedAttr.getInt() != 0)
4018 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4019 "have a parameter present";
4021 if (!hasRegion && orderedAttr.getInt() == 0)
4022 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4023 "have a parameter present";
4024 }
else if (!isa<SimdOp>(wrapper)) {
4025 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4026 "or worksharing simd loop";
4032 const OrderedOperands &clauses) {
4033 OrderedOp::build(builder, state, clauses.doacrossDependType,
4034 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4037LogicalResult OrderedOp::verify() {
4041 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4042 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4043 return emitOpError() <<
"number of variables in depend clause does not "
4044 <<
"match number of iteration variables in the "
4051 const OrderedRegionOperands &clauses) {
4052 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4062 const TaskwaitOperands &clauses) {
4064 TaskwaitOp::build(builder, state,
nullptr,
4072LogicalResult AtomicReadOp::verify() {
4073 if (verifyCommon().
failed())
4074 return mlir::failure();
4076 if (
auto mo = getMemoryOrder()) {
4077 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4078 *mo == ClauseMemoryOrderKind::Release) {
4080 "memory-order must not be acq_rel or release for atomic reads");
4090LogicalResult AtomicWriteOp::verify() {
4091 if (verifyCommon().
failed())
4092 return mlir::failure();
4094 if (
auto mo = getMemoryOrder()) {
4095 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4096 *mo == ClauseMemoryOrderKind::Acquire) {
4098 "memory-order must not be acq_rel or acquire for atomic writes");
4108LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4114 if (
Value writeVal = op.getWriteOpVal()) {
4116 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4122LogicalResult AtomicUpdateOp::verify() {
4123 if (verifyCommon().
failed())
4124 return mlir::failure();
4126 if (
auto mo = getMemoryOrder()) {
4127 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4128 *mo == ClauseMemoryOrderKind::Acquire) {
4130 "memory-order must not be acq_rel or acquire for atomic updates");
4137LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4143AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4144 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4146 return dyn_cast<AtomicReadOp>(getSecondOp());
4149AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4150 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4152 return dyn_cast<AtomicWriteOp>(getSecondOp());
4155AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4156 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4158 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4161LogicalResult AtomicCaptureOp::verify() {
4165LogicalResult AtomicCaptureOp::verifyRegions() {
4166 if (verifyRegionsCommon().
failed())
4167 return mlir::failure();
4169 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4171 "operations inside capture region must not have hint clause");
4173 if (getFirstOp()->getAttr(
"memory_order") ||
4174 getSecondOp()->getAttr(
"memory_order"))
4176 "operations inside capture region must not have memory_order clause");
4185 const CancelOperands &clauses) {
4186 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4199LogicalResult CancelOp::verify() {
4200 ClauseCancellationConstructType cct = getCancelDirective();
4203 if (!structuralParent)
4204 return emitOpError() <<
"Orphaned cancel construct";
4206 if ((cct == ClauseCancellationConstructType::Parallel) &&
4207 !mlir::isa<ParallelOp>(structuralParent)) {
4208 return emitOpError() <<
"cancel parallel must appear "
4209 <<
"inside a parallel region";
4211 if (cct == ClauseCancellationConstructType::Loop) {
4214 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4218 <<
"cancel loop must appear inside a worksharing-loop region";
4220 if (wsloopOp.getNowaitAttr()) {
4221 return emitError() <<
"A worksharing construct that is canceled "
4222 <<
"must not have a nowait clause";
4224 if (wsloopOp.getOrderedAttr()) {
4225 return emitError() <<
"A worksharing construct that is canceled "
4226 <<
"must not have an ordered clause";
4229 }
else if (cct == ClauseCancellationConstructType::Sections) {
4233 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4235 return emitOpError() <<
"cancel sections must appear "
4236 <<
"inside a sections region";
4238 if (sectionsOp.getNowait()) {
4239 return emitError() <<
"A sections construct that is canceled "
4240 <<
"must not have a nowait clause";
4243 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4244 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4245 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4246 return emitOpError() <<
"cancel taskgroup must appear "
4247 <<
"inside a task region";
4257 const CancellationPointOperands &clauses) {
4258 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4261LogicalResult CancellationPointOp::verify() {
4262 ClauseCancellationConstructType cct = getCancelDirective();
4265 if (!structuralParent)
4266 return emitOpError() <<
"Orphaned cancellation point";
4268 if ((cct == ClauseCancellationConstructType::Parallel) &&
4269 !mlir::isa<ParallelOp>(structuralParent)) {
4270 return emitOpError() <<
"cancellation point parallel must appear "
4271 <<
"inside a parallel region";
4275 if ((cct == ClauseCancellationConstructType::Loop) &&
4276 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4277 return emitOpError() <<
"cancellation point loop must appear "
4278 <<
"inside a worksharing-loop region";
4280 if ((cct == ClauseCancellationConstructType::Sections) &&
4281 !mlir::isa<omp::SectionOp>(structuralParent)) {
4282 return emitOpError() <<
"cancellation point sections must appear "
4283 <<
"inside a sections region";
4285 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4286 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4287 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4288 return emitOpError() <<
"cancellation point taskgroup must appear "
4289 <<
"inside a task region";
4298LogicalResult MapBoundsOp::verify() {
4299 auto extent = getExtent();
4301 if (!extent && !upperbound)
4302 return emitError(
"expected extent or upperbound.");
4309 PrivateClauseOp::build(
4310 odsBuilder, odsState, symName, type,
4311 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4312 DataSharingClauseType::Private));
4315LogicalResult PrivateClauseOp::verifyRegions() {
4316 Type argType = getArgType();
4317 auto verifyTerminator = [&](
Operation *terminator,
4318 bool yieldsValue) -> LogicalResult {
4322 if (!llvm::isa<YieldOp>(terminator))
4324 <<
"expected exit block terminator to be an `omp.yield` op.";
4326 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4327 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4330 if (yieldedTypes.empty())
4334 <<
"Did not expect any values to be yielded.";
4337 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4341 <<
"Invalid yielded value. Expected type: " << argType
4344 if (yieldedTypes.empty())
4347 error << yieldedTypes;
4353 StringRef regionName,
4354 bool yieldsValue) -> LogicalResult {
4355 assert(!region.
empty());
4359 <<
"`" << regionName <<
"`: "
4360 <<
"expected " << expectedNumArgs
4363 for (
Block &block : region) {
4365 if (!block.mightHaveTerminator())
4368 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4376 for (
Region *region : getRegions())
4377 for (
Type ty : region->getArgumentTypes())
4379 return emitError() <<
"Region argument type mismatch: got " << ty
4380 <<
" expected " << argType <<
".";
4383 if (!initRegion.
empty() &&
4388 DataSharingClauseType dsType = getDataSharingType();
4390 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4391 return emitError(
"`private` clauses do not require a `copy` region.");
4393 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4395 "`firstprivate` clauses require at least a `copy` region.");
4397 if (dsType == DataSharingClauseType::FirstPrivate &&
4402 if (!getDeallocRegion().empty() &&
4415 const MaskedOperands &clauses) {
4416 MaskedOp::build(builder, state, clauses.filteredThreadId);
4424 const ScanOperands &clauses) {
4425 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4428LogicalResult ScanOp::verify() {
4429 if (hasExclusiveVars() == hasInclusiveVars())
4431 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4432 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4433 if (parentWsLoopOp.getReductionModAttr() &&
4434 parentWsLoopOp.getReductionModAttr().getValue() ==
4435 ReductionModifier::inscan)
4438 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4439 if (parentSimdOp.getReductionModAttr() &&
4440 parentSimdOp.getReductionModAttr().getValue() ==
4441 ReductionModifier::inscan)
4444 return emitError(
"SCAN directive needs to be enclosed within a parent "
4445 "worksharing loop construct or SIMD construct with INSCAN "
4446 "reduction modifier");
4451LogicalResult AllocateDirOp::verify() {
4452 std::optional<uint64_t> align = this->getAlign();
4454 if (align.has_value()) {
4455 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4456 return emitError() <<
"ALIGN value : " << align.value()
4457 <<
" must be power of 2";
4467mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4468 return getInTypeAttr().getValue();
4477 bool hasOperands =
false;
4478 std::int32_t typeparamsSize = 0;
4484 return mlir::failure();
4486 return mlir::failure();
4488 return mlir::failure();
4492 return mlir::failure();
4493 result.addAttribute(
"in_type", mlir::TypeAttr::get(intype));
4500 return mlir::failure();
4501 typeparamsSize = operands.size();
4504 std::int32_t shapeSize = 0;
4508 return mlir::failure();
4509 shapeSize = operands.size() - typeparamsSize;
4511 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4512 typeVec.push_back(idxTy);
4518 return mlir::failure();
4523 return mlir::failure();
4526 result.addAttribute(
"operandSegmentSizes",
4530 return mlir::failure();
4531 return mlir::success();
4546 if (!getTypeparams().empty()) {
4547 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4554 {
"in_type",
"operandSegmentSizes"});
4557llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4559 if (!mlir::dyn_cast<IntegerType>(outType))
4561 return mlir::success();
4568LogicalResult WorkdistributeOp::verify() {
4570 Region ®ion = getRegion();
4575 if (entryBlock.
empty())
4576 return emitOpError(
"region must contain a structured block");
4578 bool hasTerminator =
false;
4579 for (
Block &block : region) {
4580 if (isa<TerminatorOp>(block.back())) {
4581 if (hasTerminator) {
4582 return emitOpError(
"region must have exactly one terminator");
4584 hasTerminator =
true;
4587 if (!hasTerminator) {
4588 return emitOpError(
"region must be terminated with omp.terminator");
4592 if (isa<BarrierOp>(op)) {
4594 "explicit barriers are not allowed in workdistribute region");
4597 if (isa<ParallelOp>(op)) {
4599 "nested parallel constructs not allowed in workdistribute");
4601 if (isa<TeamsOp>(op)) {
4603 "nested teams constructs not allowed in workdistribute");
4607 if (walkResult.wasInterrupted())
4611 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4612 return emitOpError(
"workdistribute must be nested under teams");
4620LogicalResult DeclareSimdOp::verify() {
4623 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4625 return emitOpError() <<
"must be nested inside a function";
4627 if (getInbranch() && getNotinbranch())
4628 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
4634 const DeclareSimdOperands &clauses) {
4636 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4638 clauses.linearVars, clauses.linearStepVars,
4639 clauses.linearVarTypes, clauses.notinbranch,
4640 clauses.simdlen, clauses.uniformVars);
4657 return mlir::failure();
4658 return mlir::success();
4665 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
4668 p << uniformVars[i] <<
" : " << uniformTypes[i];
4683 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4684 [&]() -> ParseResult {
return success(); })))
4718 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4719 if (parser.parseArgument(arg))
4723 if (succeeded(parser.parseOptionalColon())) {
4724 if (parser.parseType(arg.type))
4727 arg.type = parser.getBuilder().getIndexType();
4739 OpAsmParser::UnresolvedOperand lb, ub, st;
4740 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
4741 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
4742 parser.parseOperand(st))
4747 steps.push_back(st);
4755 if (ivArgs.size() != lbs.size())
4757 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
4760 for (
auto &arg : ivArgs) {
4761 lbTypes.push_back(arg.type);
4762 ubTypes.push_back(arg.type);
4763 stepTypes.push_back(arg.type);
4783 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
4786 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
4794LogicalResult IteratorOp::verify() {
4795 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
4797 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
4799 Block &
b = getRegion().front();
4800 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
4803 return emitOpError() <<
"region must be terminated by omp.yield";
4805 if (yield.getNumOperands() != 1)
4807 <<
"omp.yield in omp.iterator region must yield exactly one value";
4809 mlir::Type yieldedTy = yield.getOperand(0).getType();
4810 mlir::Type elemTy = iteratedTy.getElementType();
4812 if (yieldedTy != elemTy)
4813 return emitOpError() <<
"omp.iterated element type (" << elemTy
4814 <<
") does not match omp.yield operand type ("
4815 << yieldedTy <<
")";
4820#define GET_ATTRDEF_CLASSES
4821#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4823#define GET_OP_CLASSES
4824#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4826#define GET_TYPEDEF_CLASSES
4827#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes)
Print Linear Clause.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ©privateVars, SmallVectorImpl< Type > ©privateTypes, ArrayAttr ©privateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.