27#include "llvm/ADT/ArrayRef.h"
28#include "llvm/ADT/PostOrderIterator.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/STLForwardCompat.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/ADT/bit.h"
36#include "llvm/Support/InterleavedRange.h"
42#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
52 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
66struct MemRefPointerLikeModel
67 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
70 return llvm::cast<MemRefType>(pointer).getElementType();
74struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
101 bool isRegionArgOfOp;
111 assert(isRegionArgOfOp &&
"Must describe a region operand");
114 size_t &getArgIdx() {
115 assert(isRegionArgOfOp &&
"Must describe a region operand");
120 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
124 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
127 bool isLoopOp()
const {
128 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
129 return isa<CanonicalLoopOp>(op);
131 Region *&getParentRegion() {
132 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
135 size_t &getLoopDepth() {
136 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
140 void skipIf(
bool v =
true) { skip = skip || v; }
158 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
161 size_t sequentialIdx = -1;
162 bool isOnlyContainerOp =
true;
163 for (
Block *
b : traversal) {
165 if (&op == o && !found) {
169 if (op.getNumRegions()) {
172 isOnlyContainerOp =
false;
174 if (found && !isOnlyContainerOp)
179 Component &containerOpInRegion = components.emplace_back();
180 containerOpInRegion.isRegionArgOfOp =
false;
181 containerOpInRegion.isUnique = isOnlyContainerOp;
182 containerOpInRegion.getContainerOp() = o;
183 containerOpInRegion.getOpPos() = sequentialIdx;
184 containerOpInRegion.getParentRegion() = r;
189 Component ®ionArgOfOperation = components.emplace_back();
190 regionArgOfOperation.isRegionArgOfOp =
true;
191 regionArgOfOperation.isUnique =
true;
192 regionArgOfOperation.getArgIdx() = 0;
193 regionArgOfOperation.getOwnerOp() = parent;
205 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
209 llvm_unreachable(
"Region not child of its parent operation");
211 regionArgOfOperation.isUnique =
false;
212 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
220 for (Component &c : components)
221 c.skipIf(c.isRegionArgOfOp && c.isUnique);
224 size_t numSurroundingLoops = 0;
225 for (Component &c : llvm::reverse(components)) {
230 if (c.isRegionArgOfOp) {
231 numSurroundingLoops = 0;
238 numSurroundingLoops = 0;
240 c.getLoopDepth() = numSurroundingLoops;
243 if (isa<CanonicalLoopOp>(c.getContainerOp()))
244 numSurroundingLoops += 1;
249 bool isLoopNest =
false;
250 for (Component &c : components) {
251 if (c.skip || c.isRegionArgOfOp)
254 if (!isLoopNest && c.getLoopDepth() >= 1) {
257 }
else if (isLoopNest) {
259 c.skipIf(c.isUnique);
263 if (c.getLoopDepth() == 0)
270 for (Component &c : components)
271 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272 !isa<CanonicalLoopOp>(c.getContainerOp()));
276 bool newRegion =
true;
277 for (Component &c : llvm::reverse(components)) {
278 c.skipIf(newRegion && c.isUnique);
285 if (!c.isRegionArgOfOp && c.getContainerOp())
291 llvm::raw_svector_ostream NameOS(Name);
292 for (
auto &c : llvm::reverse(components)) {
296 if (c.isRegionArgOfOp)
297 NameOS <<
"_r" << c.getArgIdx();
298 else if (c.getLoopDepth() >= 1)
299 NameOS <<
"_d" << c.getLoopDepth();
301 NameOS <<
"_s" << c.getOpPos();
304 return NameOS.str().str();
307void OpenMPDialect::initialize() {
310#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
313#define GET_ATTRDEF_LIST
314#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
317#define GET_TYPEDEF_LIST
318#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
321 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
324 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
329 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
335 mlir::LLVM::GlobalOp::attachInterface<
338 mlir::LLVM::LLVMFuncOp::attachInterface<
341 mlir::func::FuncOp::attachInterface<
367 allocatorVars.push_back(operand);
368 allocatorTypes.push_back(type);
374 allocateVars.push_back(operand);
375 allocateTypes.push_back(type);
386 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
387 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
388 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
389 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
397template <
typename ClauseAttr>
399 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
404 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
405 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
408 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
411template <
typename ClauseAttr>
413 p << stringifyEnum(attr.getValue());
438 std::optional<omp::LinearModifier> linearModifier;
440 linearModifier = omp::LinearModifier::val;
442 linearModifier = omp::LinearModifier::ref;
444 linearModifier = omp::LinearModifier::uval;
447 bool hasLinearModifierParens = linearModifier.has_value();
448 if (hasLinearModifierParens && parser.
parseLParen())
456 if (hasLinearModifierParens && parser.
parseRParen())
459 linearVars.push_back(var);
460 linearTypes.push_back(type);
461 linearStepVars.push_back(stepVar);
462 linearStepTypes.push_back(stepType);
463 if (linearModifier) {
465 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
467 modifiers.push_back(UnitAttr::get(parser.
getContext()));
473 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
482 size_t linearVarsSize = linearVars.size();
483 for (
unsigned i = 0; i < linearVarsSize; ++i) {
487 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
488 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
490 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
492 p << linearVars[i] <<
" : " << linearTypes[i];
493 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
509 if (!linearModifiers)
511 if (linearModifiers->size() != linearVars.size())
513 <<
"expected as many linear modifiers as linear variables";
514 if (!isDeclareSimd) {
515 for (
Attribute attr : *linearModifiers) {
518 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
521 omp::LinearModifier mod = modAttr.getValue();
522 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
525 <<
"' may only be specified on a declare simd directive";
540 for (
const auto &it : nontemporalVars)
541 if (!nontemporalItems.insert(it).second)
542 return op->
emitOpError() <<
"nontemporal variable used more than once";
551 std::optional<ArrayAttr> alignments,
554 if (!alignedVars.empty()) {
555 if (!alignments || alignments->size() != alignedVars.size())
557 <<
"expected as many alignment values as aligned variables";
560 return op->
emitOpError() <<
"unexpected alignment values attribute";
566 for (
auto it : alignedVars)
567 if (!alignedItems.insert(it).second)
568 return op->
emitOpError() <<
"aligned variable used more than once";
574 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
575 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
576 if (intAttr.getValue().sle(0))
577 return op->
emitOpError() <<
"alignment should be greater than 0";
579 return op->
emitOpError() <<
"expected integer alignment";
596 if (parser.parseOperand(alignedVars.emplace_back()) ||
597 parser.parseColonType(alignedTypes.emplace_back()) ||
598 parser.parseArrow() ||
599 parser.parseAttribute(alignmentVec.emplace_back())) {
606 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
613 std::optional<ArrayAttr> alignments) {
614 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
617 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
618 p <<
" -> " << (*alignments)[i];
629 if (modifiers.size() > 2)
631 for (
const auto &mod : modifiers) {
634 auto symbol = symbolizeScheduleModifier(mod);
637 <<
" unknown modifier type: " << mod;
642 if (modifiers.size() == 1) {
643 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
644 modifiers.push_back(modifiers[0]);
645 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
else if (modifiers.size() == 2) {
650 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
651 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 <<
" incorrect modifier order";
669 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
670 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
675 std::optional<mlir::omp::ClauseScheduleKind> schedule =
676 symbolizeClauseScheduleKind(keyword);
680 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
682 case ClauseScheduleKind::Static:
683 case ClauseScheduleKind::Dynamic:
684 case ClauseScheduleKind::Guided:
690 chunkSize = std::nullopt;
693 case ClauseScheduleKind::Auto:
694 case ClauseScheduleKind::Runtime:
695 case ClauseScheduleKind::Distribute:
696 chunkSize = std::nullopt;
705 modifiers.push_back(mod);
711 if (!modifiers.empty()) {
713 if (std::optional<ScheduleModifier> mod =
714 symbolizeScheduleModifier(modifiers[0])) {
715 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
717 return parser.
emitError(loc,
"invalid schedule modifier");
720 if (modifiers.size() > 1) {
721 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
731 ClauseScheduleKindAttr scheduleKind,
732 ScheduleModifierAttr scheduleMod,
733 UnitAttr scheduleSimd,
Value scheduleChunk,
734 Type scheduleChunkType) {
735 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
739 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
751 ClauseOrderKindAttr &order,
752 OrderModifierAttr &orderMod) {
757 if (std::optional<OrderModifier> enumValue =
758 symbolizeOrderModifier(enumStr)) {
759 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
766 if (std::optional<ClauseOrderKind> enumValue =
767 symbolizeClauseOrderKind(enumStr)) {
768 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
771 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
775 ClauseOrderKindAttr order,
776 OrderModifierAttr orderMod) {
778 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
780 p << stringifyClauseOrderKind(order.getValue());
783template <
typename ClauseTypeAttr,
typename ClauseType>
786 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 std::optional<ClauseType> (*symbolizeClause)(StringRef),
789 StringRef clauseName) {
792 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
793 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
798 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
808 <<
"expected " << clauseName <<
" operand";
811 if (operand.has_value()) {
819template <
typename ClauseTypeAttr,
typename ClauseType>
822 ClauseTypeAttr prescriptiveness,
Value operand,
824 StringRef (*stringifyClauseType)(ClauseType)) {
826 if (prescriptiveness)
827 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
830 p << operand <<
": " << operandType;
840 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
841 Type &grainsizeType) {
843 parser, grainsizeMod, grainsize, grainsizeType,
844 &symbolizeClauseGrainsizeType,
"grainsize");
848 ClauseGrainsizeTypeAttr grainsizeMod,
851 p, op, grainsizeMod, grainsize, grainsizeType,
852 &stringifyClauseGrainsizeType);
862 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
863 Type &numTasksType) {
865 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
870 ClauseNumTasksTypeAttr numTasksMod,
873 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
882 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
883 SmallVectorImpl<Type> &types;
884 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
885 SmallVectorImpl<Type> &types)
886 : vars(vars), types(types) {}
888struct PrivateParseArgs {
889 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
890 llvm::SmallVectorImpl<Type> &types;
892 UnitAttr &needsBarrier;
894 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
895 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
896 UnitAttr &needsBarrier,
898 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
899 mapIndices(mapIndices) {}
902struct ReductionParseArgs {
903 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
904 SmallVectorImpl<Type> &types;
907 ReductionModifierAttr *modifier;
908 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
910 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
911 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
914struct AllRegionParseArgs {
915 std::optional<MapParseArgs> hasDeviceAddrArgs;
916 std::optional<MapParseArgs> hostEvalArgs;
917 std::optional<ReductionParseArgs> inReductionArgs;
918 std::optional<MapParseArgs> mapArgs;
919 std::optional<PrivateParseArgs> privateArgs;
920 std::optional<ReductionParseArgs> reductionArgs;
921 std::optional<ReductionParseArgs> taskReductionArgs;
922 std::optional<MapParseArgs> useDeviceAddrArgs;
923 std::optional<MapParseArgs> useDevicePtrArgs;
928 return "private_barrier";
938 ReductionModifierAttr *modifier =
nullptr,
939 UnitAttr *needsBarrier =
nullptr) {
943 unsigned regionArgOffset = regionPrivateArgs.size();
953 std::optional<ReductionModifier> enumValue =
954 symbolizeReductionModifier(enumStr);
955 if (!enumValue.has_value())
957 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
964 isByRefVec.push_back(
965 parser.parseOptionalKeyword(
"byref").succeeded());
967 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
970 if (parser.parseOperand(operands.emplace_back()) ||
971 parser.parseArrow() ||
972 parser.parseArgument(regionPrivateArgs.emplace_back()))
976 if (parser.parseOptionalLSquare().succeeded()) {
977 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
978 parser.parseInteger(mapIndicesVec.emplace_back()) ||
979 parser.parseRSquare())
982 mapIndicesVec.push_back(-1);
994 if (parser.parseType(types.emplace_back()))
1001 if (operands.size() != types.size())
1010 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1013 auto *argsBegin = regionPrivateArgs.begin();
1015 argsBegin + regionArgOffset + types.size());
1016 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1022 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1025 if (!mapIndicesVec.empty())
1038 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1053 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1059 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1060 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1061 nullptr, &privateArgs->needsBarrier)))
1070 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1075 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1076 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1077 reductionArgs->modifier)))
1084 AllRegionParseArgs args) {
1088 args.hasDeviceAddrArgs)))
1090 <<
"invalid `has_device_addr` format";
1093 args.hostEvalArgs)))
1095 <<
"invalid `host_eval` format";
1098 args.inReductionArgs)))
1100 <<
"invalid `in_reduction` format";
1105 <<
"invalid `map_entries` format";
1110 <<
"invalid `private` format";
1113 args.reductionArgs)))
1115 <<
"invalid `reduction` format";
1118 args.taskReductionArgs)))
1120 <<
"invalid `task_reduction` format";
1123 args.useDeviceAddrArgs)))
1125 <<
"invalid `use_device_addr` format";
1128 args.useDevicePtrArgs)))
1130 <<
// Report the clause actually being parsed here: this diagnostic belongs to
// the `use_device_ptr` clause check, not `use_device_addr`.
"invalid `use_device_ptr` format";
1132 return parser.
parseRegion(region, entryBlockArgs);
1151 AllRegionParseArgs args;
1152 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1153 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1154 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1155 inReductionByref, inReductionSyms);
1156 args.mapArgs.emplace(mapVars, mapTypes);
1157 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1158 privateNeedsBarrier, &privateMaps);
1169 UnitAttr &privateNeedsBarrier) {
1170 AllRegionParseArgs args;
1171 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1172 inReductionByref, inReductionSyms);
1173 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1174 privateNeedsBarrier);
1185 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1189 AllRegionParseArgs args;
1190 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1191 inReductionByref, inReductionSyms);
1192 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1193 privateNeedsBarrier);
1194 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1195 reductionSyms, &reductionMod);
1203 UnitAttr &privateNeedsBarrier) {
1204 AllRegionParseArgs args;
1205 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1206 privateNeedsBarrier);
1214 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1218 AllRegionParseArgs args;
1219 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1220 privateNeedsBarrier);
1221 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1222 reductionSyms, &reductionMod);
1231 AllRegionParseArgs args;
1232 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1233 taskReductionByref, taskReductionSyms);
1243 AllRegionParseArgs args;
1244 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1245 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1254struct MapPrintArgs {
1259struct PrivatePrintArgs {
1263 UnitAttr needsBarrier;
1267 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1268 mapIndices(mapIndices) {}
1270struct ReductionPrintArgs {
1275 ReductionModifierAttr modifier;
1277 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1278 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1280struct AllRegionPrintArgs {
1281 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1282 std::optional<MapPrintArgs> hostEvalArgs;
1283 std::optional<ReductionPrintArgs> inReductionArgs;
1284 std::optional<MapPrintArgs> mapArgs;
1285 std::optional<PrivatePrintArgs> privateArgs;
1286 std::optional<ReductionPrintArgs> reductionArgs;
1287 std::optional<ReductionPrintArgs> taskReductionArgs;
1288 std::optional<MapPrintArgs> useDeviceAddrArgs;
1289 std::optional<MapPrintArgs> useDevicePtrArgs;
1298 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1299 if (argsSubrange.empty())
1302 p << clauseName <<
"(";
1305 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1309 symbols = ArrayAttr::get(ctx, values);
1322 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1323 mapIndices.asArrayRef(),
1324 byref.asArrayRef()),
1326 auto [op, arg, sym, map, isByRef] = t;
1332 p << op <<
" -> " << arg;
1335 p <<
" [map_idx=" << map <<
"]";
1338 llvm::interleaveComma(types, p);
1346 StringRef clauseName,
ValueRange argsSubrange,
1347 std::optional<MapPrintArgs> mapArgs) {
1354 StringRef clauseName,
ValueRange argsSubrange,
1355 std::optional<PrivatePrintArgs> privateArgs) {
1358 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1359 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1360 nullptr, privateArgs->needsBarrier);
1366 std::optional<ReductionPrintArgs> reductionArgs) {
1369 reductionArgs->vars, reductionArgs->types,
1370 reductionArgs->syms,
nullptr,
1371 reductionArgs->byref, reductionArgs->modifier);
1375 const AllRegionPrintArgs &args) {
1376 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1380 iface.getHasDeviceAddrBlockArgs(),
1381 args.hasDeviceAddrArgs);
1385 args.inReductionArgs);
1391 args.reductionArgs);
1393 iface.getTaskReductionBlockArgs(),
1394 args.taskReductionArgs);
1396 iface.getUseDeviceAddrBlockArgs(),
1397 args.useDeviceAddrArgs);
1399 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1415 AllRegionPrintArgs args;
1416 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1417 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1418 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1419 inReductionByref, inReductionSyms);
1420 args.mapArgs.emplace(mapVars, mapTypes);
1421 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1422 privateNeedsBarrier, privateMaps);
1430 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1431 AllRegionPrintArgs args;
1432 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1433 inReductionByref, inReductionSyms);
1434 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1435 privateNeedsBarrier,
1444 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1445 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1448 AllRegionPrintArgs args;
1449 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1450 inReductionByref, inReductionSyms);
1451 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1452 privateNeedsBarrier,
1454 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1455 reductionSyms, reductionMod);
1462 UnitAttr privateNeedsBarrier) {
1463 AllRegionPrintArgs args;
1464 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1465 privateNeedsBarrier,
1473 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1476 AllRegionPrintArgs args;
1477 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1478 privateNeedsBarrier,
1480 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1481 reductionSyms, reductionMod);
1491 AllRegionPrintArgs args;
1492 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1493 taskReductionByref, taskReductionSyms);
1503 AllRegionPrintArgs args;
1504 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1505 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1509template <
typename ParsePrefixFn>
1518 if (failed(parsePrefix()))
1526 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1527 iteratedVars.push_back(v);
1528 iteratedTypes.push_back(ty);
1530 plainVars.push_back(v);
1531 plainTypes.push_back(ty);
1537template <
typename Pr
intPrefixFn>
1541 PrintPrefixFn &&printPrefixForPlain,
1542 PrintPrefixFn &&printPrefixForIterated) {
1549 p << v <<
" : " << t;
1553 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1554 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1555 for (
unsigned i = 0; i < plainVars.size(); ++i)
1556 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1564 if (!reductionVars.empty()) {
1565 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1567 <<
"expected as many reduction symbol references "
1568 "as reduction variables";
1569 if (reductionByref && reductionByref->size() != reductionVars.size())
1570 return op->
emitError() <<
"expected as many reduction variable by "
1571 "reference attributes as reduction variables";
1574 return op->
emitOpError() <<
"unexpected reduction symbol references";
1581 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1582 Value accum = std::get<0>(args);
1584 if (!accumulators.insert(accum).second)
1585 return op->
emitOpError() <<
"accumulator variable used more than once";
1588 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1592 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1593 <<
" to point to a reduction declaration";
1595 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1597 <<
"expected accumulator (" << varType
1598 <<
") to be the same type as reduction declaration ("
1599 << decl.getAccumulatorType() <<
")";
1618 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1619 parser.parseArrow() ||
1620 parser.parseAttribute(symsVec.emplace_back()) ||
1621 parser.parseColonType(copyprivateTypes.emplace_back()))
1627 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1635 std::optional<ArrayAttr> copyprivateSyms) {
1636 if (!copyprivateSyms.has_value())
1638 llvm::interleaveComma(
1639 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1640 [&](
const auto &args) {
1641 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1642 << std::get<2>(args);
1649 std::optional<ArrayAttr> copyprivateSyms) {
1650 size_t copyprivateSymsSize =
1651 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1652 if (copyprivateSymsSize != copyprivateVars.size())
1653 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1654 << copyprivateVars.size()
1655 <<
") and functions (= " << copyprivateSymsSize
1656 <<
"), both must be equal";
1657 if (!copyprivateSyms.has_value())
1660 for (
auto copyprivateVarAndSym :
1661 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1663 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1664 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1666 if (mlir::func::FuncOp mlirFuncOp =
1669 funcOp = mlirFuncOp;
1670 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1673 funcOp = llvmFuncOp;
1675 auto getNumArguments = [&] {
1676 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1679 auto getArgumentType = [&](
unsigned i) {
1680 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1685 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1686 <<
" to point to a copy function";
1688 if (getNumArguments() != 2)
1690 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1692 Type argTy = getArgumentType(0);
1693 if (argTy != getArgumentType(1))
1694 return op->
emitOpError() <<
"expected copy function " << symbolRef
1695 <<
" arguments to have the same type";
1697 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1698 if (argTy != varType)
1700 <<
"expected copy function arguments' type (" << argTy
1701 <<
") to be the same as copyprivate variable's type (" << varType
1726 OpAsmParser::UnresolvedOperand operand;
1728 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1729 parser.parseOperand(operand) || parser.parseColonType(ty))
1731 std::optional<ClauseTaskDepend> keywordDepend =
1732 symbolizeClauseTaskDepend(keyword);
1736 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1737 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1738 iteratedVars.push_back(operand);
1739 iteratedTypes.push_back(ty);
1740 iterKindsVec.push_back(kindAttr);
1742 dependVars.push_back(operand);
1743 dependTypes.push_back(ty);
1744 kindsVec.push_back(kindAttr);
1750 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1752 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1759 std::optional<ArrayAttr> dependKinds,
1762 std::optional<ArrayAttr> iteratedKinds) {
1765 std::optional<ArrayAttr> kinds) {
1766 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1769 p << stringifyClauseTaskDepend(
1770 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1772 <<
" -> " << vars[i] <<
" : " << types[i];
1776 printEntries(dependVars, dependTypes, dependKinds);
1777 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1782 std::optional<ArrayAttr> dependKinds,
1784 std::optional<ArrayAttr> iteratedKinds,
1786 if (!dependVars.empty()) {
1787 if (!dependKinds || dependKinds->size() != dependVars.size())
1788 return op->
emitOpError() <<
"expected as many depend values"
1789 " as depend variables";
1791 if (dependKinds && !dependKinds->empty())
1792 return op->
emitOpError() <<
"unexpected depend values";
1795 if (!iteratedVars.empty()) {
1796 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1797 return op->
emitOpError() <<
"expected as many depend iterated values"
1798 " as depend iterated variables";
1800 if (iteratedKinds && !iteratedKinds->empty())
1801 return op->
emitOpError() <<
"unexpected depend iterated values";
1816 IntegerAttr &hintAttr) {
1817 StringRef hintKeyword;
1823 auto parseKeyword = [&]() -> ParseResult {
1826 if (hintKeyword ==
"uncontended")
1828 else if (hintKeyword ==
"contended")
1830 else if (hintKeyword ==
"nonspeculative")
1832 else if (hintKeyword ==
"speculative")
1836 << hintKeyword <<
" is not a valid hint";
1847 IntegerAttr hintAttr) {
1848 int64_t hint = hintAttr.getInt();
1856 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1858 bool uncontended = bitn(hint, 0);
1859 bool contended = bitn(hint, 1);
1860 bool nonspeculative = bitn(hint, 2);
1861 bool speculative = bitn(hint, 3);
1865 hints.push_back(
"uncontended");
1867 hints.push_back(
"contended");
1869 hints.push_back(
"nonspeculative");
1871 hints.push_back(
"speculative");
1873 llvm::interleaveComma(hints, p);
1880 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1882 bool uncontended = bitn(hint, 0);
1883 bool contended = bitn(hint, 1);
1884 bool nonspeculative = bitn(hint, 2);
1885 bool speculative = bitn(hint, 3);
1887 if (uncontended && contended)
1888 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1889 "omp_sync_hint_contended cannot be combined";
1890 if (nonspeculative && speculative)
1891 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1892 "omp_sync_hint_speculative cannot be combined.";
1903 return (value & flag) == flag;
1911static ParseResult parseMapClause(
OpAsmParser &parser,
1912 ClauseMapFlagsAttr &mapType) {
1913 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1916 auto parseTypeAndMod = [&]() -> ParseResult {
1917 StringRef mapTypeMod;
1921 if (mapTypeMod ==
"always")
1922 mapTypeBits |= ClauseMapFlags::always;
1924 if (mapTypeMod ==
"implicit")
1925 mapTypeBits |= ClauseMapFlags::implicit;
1927 if (mapTypeMod ==
"ompx_hold")
1928 mapTypeBits |= ClauseMapFlags::ompx_hold;
1930 if (mapTypeMod ==
"close")
1931 mapTypeBits |= ClauseMapFlags::close;
1933 if (mapTypeMod ==
"present")
1934 mapTypeBits |= ClauseMapFlags::present;
1936 if (mapTypeMod ==
"to")
1937 mapTypeBits |= ClauseMapFlags::to;
1939 if (mapTypeMod ==
"from")
1940 mapTypeBits |= ClauseMapFlags::from;
1942 if (mapTypeMod ==
"tofrom")
1943 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1945 if (mapTypeMod ==
"delete")
1946 mapTypeBits |= ClauseMapFlags::del;
1948 if (mapTypeMod ==
"storage")
1949 mapTypeBits |= ClauseMapFlags::storage;
1951 if (mapTypeMod ==
"return_param")
1952 mapTypeBits |= ClauseMapFlags::return_param;
1954 if (mapTypeMod ==
"private")
1955 mapTypeBits |= ClauseMapFlags::priv;
1957 if (mapTypeMod ==
"literal")
1958 mapTypeBits |= ClauseMapFlags::literal;
1960 if (mapTypeMod ==
"attach")
1961 mapTypeBits |= ClauseMapFlags::attach;
1963 if (mapTypeMod ==
"attach_always")
1964 mapTypeBits |= ClauseMapFlags::attach_always;
1966 if (mapTypeMod ==
"attach_never")
1967 mapTypeBits |= ClauseMapFlags::attach_never;
1969 if (mapTypeMod ==
"attach_auto")
1970 mapTypeBits |= ClauseMapFlags::attach_auto;
1972 if (mapTypeMod ==
"ref_ptr")
1973 mapTypeBits |= ClauseMapFlags::ref_ptr;
1975 if (mapTypeMod ==
"ref_ptee")
1976 mapTypeBits |= ClauseMapFlags::ref_ptee;
1978 if (mapTypeMod ==
"ref_ptr_ptee")
1979 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1981 if (mapTypeMod ==
"is_device_ptr")
1982 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1999 ClauseMapFlagsAttr mapType) {
2001 ClauseMapFlags mapFlags = mapType.getValue();
2006 mapTypeStrs.push_back(
"always");
2008 mapTypeStrs.push_back(
"implicit");
2010 mapTypeStrs.push_back(
"ompx_hold");
2012 mapTypeStrs.push_back(
"close");
2014 mapTypeStrs.push_back(
"present");
2023 mapTypeStrs.push_back(
"tofrom");
2025 mapTypeStrs.push_back(
"from");
2027 mapTypeStrs.push_back(
"to");
2030 mapTypeStrs.push_back(
"delete");
2032 mapTypeStrs.push_back(
"return_param");
2034 mapTypeStrs.push_back(
"storage");
2036 mapTypeStrs.push_back(
"private");
2038 mapTypeStrs.push_back(
"literal");
2040 mapTypeStrs.push_back(
"attach");
2042 mapTypeStrs.push_back(
"attach_always");
2044 mapTypeStrs.push_back(
"attach_never");
2046 mapTypeStrs.push_back(
"attach_auto");
2048 mapTypeStrs.push_back(
"ref_ptr");
2050 mapTypeStrs.push_back(
"ref_ptee");
2052 mapTypeStrs.push_back(
"ref_ptr_ptee");
2054 mapTypeStrs.push_back(
"is_device_ptr");
2055 if (mapFlags == ClauseMapFlags::none)
2056 mapTypeStrs.push_back(
"none");
2058 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2059 p << mapTypeStrs[i];
2060 if (i + 1 < mapTypeStrs.size()) {
2066static ParseResult parseMembersIndex(
OpAsmParser &parser,
2070 auto parseIndices = [&]() -> ParseResult {
2075 APInt(64, value,
false)));
2089 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2093 if (!memberIdxs.empty())
2094 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2104 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2106 auto memberIdx = cast<ArrayAttr>(v);
2107 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2108 p << cast<IntegerAttr>(v2).getInt();
2115 VariableCaptureKindAttr mapCaptureType) {
2116 std::string typeCapStr;
2117 llvm::raw_string_ostream typeCap(typeCapStr);
2118 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2120 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2121 typeCap <<
"ByCopy";
2122 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2123 typeCap <<
"VLAType";
2124 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2130 VariableCaptureKindAttr &mapCaptureType) {
2131 StringRef mapCaptureKey;
2135 if (mapCaptureKey ==
"This")
2136 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2137 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2138 if (mapCaptureKey ==
"ByRef")
2139 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2140 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2141 if (mapCaptureKey ==
"ByCopy")
2142 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2143 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2144 if (mapCaptureKey ==
"VLAType")
2145 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2146 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2155 for (
auto mapOp : mapVars) {
2156 if (!mapOp.getDefiningOp())
2159 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2160 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2163 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2166 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2167 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2168 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2170 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2172 "to, from, tofrom and alloc map types are permitted");
2174 if (isa<TargetEnterDataOp>(op) && (from || del))
2175 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2177 if (isa<TargetExitDataOp>(op) && to)
2179 "from, release and delete map types are permitted");
2181 if (isa<TargetUpdateOp>(op)) {
2184 "at least one of to or from map types must be "
2185 "specified, other map types are not permitted");
2190 "at least one of to or from map types must be "
2191 "specified, other map types are not permitted");
2194 auto updateVar = mapInfoOp.getVarPtr();
2196 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2197 (from && updateToVars.contains(updateVar))) {
2200 "either to or from map types can be specified, not both");
2203 if (always || close || implicit) {
2206 "present, mapper and iterator map type modifiers are permitted");
2209 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2211 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2213 "map argument is not a map entry operation");
2221 std::optional<DenseI64ArrayAttr> privateMapIndices =
2222 targetOp.getPrivateMapsAttr();
2225 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2230 if (privateMapIndices.value().size() !=
2231 static_cast<int64_t>(privateVars.size()))
2232 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2233 "`private_maps` attribute mismatch");
2243 StringRef clauseName,
2245 for (
Value var : vars)
2246 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2248 <<
"'" << clauseName
2249 <<
"' arguments must be defined by 'omp.map.info' ops";
2253LogicalResult MapInfoOp::verify() {
2254 if (getMapperId() &&
2256 *
this, getMapperIdAttr())) {
2271 const TargetDataOperands &clauses) {
2272 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2273 clauses.mapVars, clauses.useDeviceAddrVars,
2274 clauses.useDevicePtrVars);
2277LogicalResult TargetDataOp::verify() {
2278 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2279 getUseDeviceAddrVars().empty()) {
2280 return ::emitError(this->getLoc(),
2281 "At least one of map, use_device_ptr_vars, or "
2282 "use_device_addr_vars operand must be present");
2286 getUseDevicePtrVars())))
2290 getUseDeviceAddrVars())))
2300void TargetEnterDataOp::build(
2304 TargetEnterDataOp::build(
2306 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2307 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2311LogicalResult TargetEnterDataOp::verify() {
2312 LogicalResult verifyDependVars =
2314 getDependIteratedKinds(), getDependIterated());
2315 return failed(verifyDependVars) ? verifyDependVars
2326 TargetExitDataOp::build(
2328 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2329 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2333LogicalResult TargetExitDataOp::verify() {
2334 LogicalResult verifyDependVars =
2336 getDependIteratedKinds(), getDependIterated());
2337 return failed(verifyDependVars) ? verifyDependVars
2348 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2351 clauses.dependIterated, clauses.device, clauses.ifExpr,
2352 clauses.mapVars, clauses.nowait);
2355LogicalResult TargetUpdateOp::verify() {
2356 LogicalResult verifyDependVars =
2358 getDependIteratedKinds(), getDependIterated());
2359 return failed(verifyDependVars) ? verifyDependVars
2368 const TargetOperands &clauses) {
2373 builder, state, {}, {}, clauses.bare,
2374 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2375 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2376 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2379 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2380 clauses.nowait, clauses.privateVars,
2381 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2382 clauses.threadLimitVars,
2386LogicalResult TargetOp::verify() {
2388 getDependIteratedKinds(),
2389 getDependIterated())))
2393 getHasDeviceAddrVars())))
2402LogicalResult TargetOp::verifyRegions() {
2403 auto teamsOps = getOps<TeamsOp>();
2404 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2405 return emitError(
"target containing multiple 'omp.teams' nested ops");
2408 Operation *capturedOp = getInnermostCapturedOmpOp();
2409 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2410 for (
Value hostEvalArg :
2411 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2413 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2415 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2416 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2417 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2420 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2421 "and 'thread_limit' in 'omp.teams'";
2423 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2424 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2425 parallelOp->isAncestor(capturedOp) &&
2426 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2430 <<
"host_eval argument only legal as 'num_threads' in "
2431 "'omp.parallel' when representing target SPMD";
2433 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2434 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2435 loopNestOp.getOperation() == capturedOp &&
2436 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2437 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2441 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2442 "and steps in 'omp.loop_nest' when trip count "
2443 "must be evaluated in the host";
2446 return emitOpError() <<
"host_eval argument illegal use in '"
2447 << user->getName() <<
"' operation";
2456 assert(rootOp &&
"expected valid operation");
2473 bool isOmpDialect = op->
getDialect() == ompDialect;
2475 if (!isOmpDialect || !hasRegions)
2482 if (checkSingleMandatoryExec) {
2487 if (successor->isReachable(parentBlock))
2490 for (
Block &block : *parentRegion)
2492 !domInfo.
dominates(parentBlock, &block))
2499 if (&sibling != op && !siblingAllowedFn(&sibling))
2512Operation *TargetOp::getInnermostCapturedOmpOp() {
2513 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2525 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2528 memOp.getEffects(effects);
2529 return !llvm::any_of(
2531 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2532 isa<SideEffects::AutomaticAllocationScopeResource>(
2542 WsloopOp *wsLoopOp) {
2544 if (!teamsOp.getNumTeamsUpperVars().empty())
2548 if (teamsOp.getNumReductionVars())
2550 if (wsLoopOp->getNumReductionVars())
2554 OffloadModuleInterface offloadMod =
2558 auto ompFlags = offloadMod.getFlags();
2561 return ompFlags.getAssumeTeamsOversubscription() &&
2562 ompFlags.getAssumeThreadsOversubscription();
2565TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2570 assert((!capturedOp ||
2571 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2572 "unexpected captured op");
2575 if (!isa_and_present<LoopNestOp>(capturedOp))
2576 return TargetRegionFlags::generic;
2580 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2581 assert(!loopWrappers.empty());
2583 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2584 if (isa<SimdOp>(innermostWrapper))
2585 innermostWrapper = std::next(innermostWrapper);
2587 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2588 if (numWrappers != 1 && numWrappers != 2)
2589 return TargetRegionFlags::generic;
2592 if (numWrappers == 2) {
2593 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2595 return TargetRegionFlags::generic;
2597 innermostWrapper = std::next(innermostWrapper);
2598 if (!isa<DistributeOp>(innermostWrapper))
2599 return TargetRegionFlags::generic;
2602 if (!isa_and_present<ParallelOp>(parallelOp))
2603 return TargetRegionFlags::generic;
2605 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2607 return TargetRegionFlags::generic;
2609 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2610 TargetRegionFlags
result =
2611 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2618 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2620 if (!isa_and_present<TeamsOp>(teamsOp))
2621 return TargetRegionFlags::generic;
2623 if (teamsOp->
getParentOp() != targetOp.getOperation())
2624 return TargetRegionFlags::generic;
2626 if (isa<LoopOp>(innermostWrapper))
2627 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2637 Dialect *ompDialect = targetOp->getDialect();
2641 return sibling && (ompDialect != sibling->
getDialect() ||
2645 TargetRegionFlags
result =
2646 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2651 while (nestedCapture->
getParentOp() != capturedOp)
2654 return isa<ParallelOp>(nestedCapture) ?
result | TargetRegionFlags::spmd
2658 else if (isa<WsloopOp>(innermostWrapper)) {
2660 if (!isa_and_present<ParallelOp>(parallelOp))
2661 return TargetRegionFlags::generic;
2663 if (parallelOp->
getParentOp() == targetOp.getOperation())
2664 return TargetRegionFlags::spmd;
2667 return TargetRegionFlags::generic;
2676 ParallelOp::build(builder, state,
ValueRange(),
2688 const ParallelOperands &clauses) {
2690 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2691 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2693 clauses.privateNeedsBarrier, clauses.procBindKind,
2694 clauses.reductionMod, clauses.reductionVars,
2699template <
typename OpType>
2701 auto privateVars = op.getPrivateVars();
2702 auto privateSyms = op.getPrivateSymsAttr();
2704 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2707 auto numPrivateVars = privateVars.size();
2708 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2710 if (numPrivateVars != numPrivateSyms)
2711 return op.emitError() <<
"inconsistent number of private variables and "
2712 "privatizer op symbols, private vars: "
2714 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2716 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2717 Type varType = std::get<0>(privateVarInfo).getType();
2718 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2719 PrivateClauseOp privatizerOp =
2722 if (privatizerOp ==
nullptr)
2723 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2724 << privateSym <<
"'";
2726 Type privatizerType = privatizerOp.getArgType();
2728 if (privatizerType && (varType != privatizerType))
2729 return op.emitError()
2730 <<
"type mismatch between a "
2731 << (privatizerOp.getDataSharingType() ==
2732 DataSharingClauseType::Private
2735 <<
" variable and its privatizer op, var type: " << varType
2736 <<
" vs. privatizer op type: " << privatizerType;
2742LogicalResult ParallelOp::verify() {
2743 if (getAllocateVars().size() != getAllocatorVars().size())
2745 "expected equal sizes for allocate and allocator variables");
2751 getReductionByref());
2754LogicalResult ParallelOp::verifyRegions() {
2755 auto distChildOps = getOps<DistributeOp>();
2756 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2757 if (numDistChildOps > 1)
2759 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2761 if (numDistChildOps == 1) {
2764 <<
"'omp.composite' attribute missing from composite operation";
2766 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2767 Operation &distributeOp = **distChildOps.begin();
2769 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2773 return emitError() <<
"unexpected OpenMP operation inside of composite "
2775 << childOp.getName();
2777 }
else if (isComposite()) {
2779 <<
"'omp.composite' attribute present in non-composite operation";
2796 const TeamsOperands &clauses) {
2800 builder, state, clauses.allocateVars, clauses.allocatorVars,
2801 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2803 nullptr, clauses.reductionMod,
2804 clauses.reductionVars,
2806 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2813 if (numTeamsLower) {
2814 if (numTeamsUpperVars.size() != 1)
2816 "expected exactly one num_teams upper bound when lower bound is "
2820 "expected num_teams upper bound and lower bound to be "
2827LogicalResult TeamsOp::verify() {
2836 return emitError(
"expected to be nested inside of omp.target or not nested "
2837 "in any OpenMP dialect operations");
2841 this->getNumTeamsUpperVars())))
2845 if (getAllocateVars().size() != getAllocatorVars().size())
2847 "expected equal sizes for allocate and allocator variables");
2850 getReductionByref());
2858 return getParentOp().getPrivateVars();
2862 return getParentOp().getReductionVars();
2870 const SectionsOperands &clauses) {
2873 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2876 clauses.reductionMod, clauses.reductionVars,
2881LogicalResult SectionsOp::verify() {
2882 if (getAllocateVars().size() != getAllocatorVars().size())
2884 "expected equal sizes for allocate and allocator variables");
2887 getReductionByref());
2890LogicalResult SectionsOp::verifyRegions() {
2891 for (
auto &inst : *getRegion().begin()) {
2892 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2894 <<
"expected omp.section op or terminator op inside region";
2906 const SingleOperands &clauses) {
2909 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2910 clauses.copyprivateVars,
2911 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2916LogicalResult SingleOp::verify() {
2918 if (getAllocateVars().size() != getAllocatorVars().size())
2920 "expected equal sizes for allocate and allocator variables");
2923 getCopyprivateSyms());
2931 const WorkshareOperands &clauses) {
2932 WorkshareOp::build(builder, state, clauses.nowait);
2939LogicalResult WorkshareLoopWrapperOp::verify() {
2940 if (!(*this)->getParentOfType<WorkshareOp>())
2941 return emitOpError() <<
"must be nested in an omp.workshare";
2945LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2946 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2948 return emitOpError() <<
"expected to be a standalone loop wrapper";
2957LogicalResult LoopWrapperInterface::verifyImpl() {
2961 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2962 "and `SingleBlock` traits";
2965 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2968 if (range_size(region.
getOps()) != 1)
2970 <<
"loop wrapper does not contain exactly one nested op";
2973 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2974 return emitOpError() <<
"nested in loop wrapper is not another loop "
2975 "wrapper or `omp.loop_nest`";
2985 const LoopOperands &clauses) {
2988 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2990 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2991 clauses.reductionMod, clauses.reductionVars,
2996LogicalResult LoopOp::verify() {
2998 getReductionByref());
3001LogicalResult LoopOp::verifyRegions() {
3002 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3004 return emitOpError() <<
"expected to be a standalone loop wrapper";
3015 build(builder, state, {}, {},
3018 false,
nullptr,
nullptr,
3019 nullptr, {},
nullptr,
3030 const WsloopOperands &clauses) {
3035 {}, {}, clauses.linearVars,
3036 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3037 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3038 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3039 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3041 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3042 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3045LogicalResult WsloopOp::verify() {
3049 if (getLinearVars().size() &&
3050 getLinearVarTypes().value().size() != getLinearVars().size())
3051 return emitError() <<
"Ill-formed type attributes for linear variables";
3053 getReductionByref());
3056LogicalResult WsloopOp::verifyRegions() {
3057 bool isCompositeChildLeaf =
3058 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3063 <<
"'omp.composite' attribute missing from composite wrapper";
3067 if (!isa<SimdOp>(nested))
3068 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3070 }
else if (isComposite() && !isCompositeChildLeaf) {
3072 <<
"'omp.composite' attribute present in non-composite wrapper";
3073 }
else if (!isComposite() && isCompositeChildLeaf) {
3075 <<
"'omp.composite' attribute missing from composite wrapper";
3086 const SimdOperands &clauses) {
3088 SimdOp::build(builder, state, clauses.alignedVars,
3090 clauses.linearVars, clauses.linearStepVars,
3091 clauses.linearVarTypes, clauses.linearModifiers,
3092 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3093 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3094 clauses.privateNeedsBarrier, clauses.reductionMod,
3095 clauses.reductionVars,
3101LogicalResult SimdOp::verify() {
3102 if (getSimdlen().has_value() && getSafelen().has_value() &&
3103 getSimdlen().value() > getSafelen().value())
3105 <<
"simdlen clause and safelen clause are both present, but the "
3106 "simdlen value is not less than or equal to safelen value";
3118 bool isCompositeChildLeaf =
3119 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3121 if (!isComposite() && isCompositeChildLeaf)
3123 <<
"'omp.composite' attribute missing from composite wrapper";
3125 if (isComposite() && !isCompositeChildLeaf)
3127 <<
"'omp.composite' attribute present in non-composite wrapper";
3131 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3133 for (
const Attribute &sym : *privateSyms) {
3134 auto symRef = cast<SymbolRefAttr>(sym);
3135 omp::PrivateClauseOp privatizer =
3137 getOperation(), symRef);
3139 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3140 if (privatizer.getDataSharingType() ==
3141 DataSharingClauseType::FirstPrivate)
3142 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3146 if (getLinearVars().size() &&
3147 getLinearVarTypes().value().size() != getLinearVars().size())
3148 return emitError() <<
"Ill-formed type attributes for linear variables";
3152LogicalResult SimdOp::verifyRegions() {
3153 if (getNestedWrapper())
3154 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3164 const DistributeOperands &clauses) {
3165 DistributeOp::build(builder, state, clauses.allocateVars,
3166 clauses.allocatorVars, clauses.distScheduleStatic,
3167 clauses.distScheduleChunkSize, clauses.order,
3168 clauses.orderMod, clauses.privateVars,
3170 clauses.privateNeedsBarrier);
3173LogicalResult DistributeOp::verify() {
3174 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3176 "dist_schedule_static being present";
3178 if (getAllocateVars().size() != getAllocatorVars().size())
3180 "expected equal sizes for allocate and allocator variables");
3185LogicalResult DistributeOp::verifyRegions() {
3186 if (LoopWrapperInterface nested = getNestedWrapper()) {
3189 <<
"'omp.composite' attribute missing from composite wrapper";
3192 if (isa<WsloopOp>(nested)) {
3194 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3195 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3196 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3197 "when a composite 'omp.parallel' is the direct "
3200 }
else if (!isa<SimdOp>(nested))
3201 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3203 }
else if (isComposite()) {
3205 <<
"'omp.composite' attribute present in non-composite wrapper";
3215LogicalResult DeclareMapperInfoOp::verify() {
3219LogicalResult DeclareMapperOp::verifyRegions() {
3220 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3221 getRegion().getBlocks().front().getTerminator()))
3222 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3231LogicalResult DeclareReductionOp::verifyRegions() {
3232 if (!getAllocRegion().empty()) {
3233 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3234 if (yieldOp.getResults().size() != 1 ||
3235 yieldOp.getResults().getTypes()[0] !=
getType())
3236 return emitOpError() <<
"expects alloc region to yield a value "
3237 "of the reduction type";
3241 if (getInitializerRegion().empty())
3242 return emitOpError() <<
"expects non-empty initializer region";
3243 Block &initializerEntryBlock = getInitializerRegion().
front();
3246 if (!getAllocRegion().empty())
3247 return emitOpError() <<
"expects two arguments to the initializer region "
3248 "when an allocation region is used";
3250 if (getAllocRegion().empty())
3251 return emitOpError() <<
"expects one argument to the initializer region "
3252 "when no allocation region is used";
3255 <<
"expects one or two arguments to the initializer region";
3259 if (arg.getType() !=
getType())
3260 return emitOpError() <<
"expects initializer region argument to match "
3261 "the reduction type";
3263 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3264 if (yieldOp.getResults().size() != 1 ||
3265 yieldOp.getResults().getTypes()[0] !=
getType())
3266 return emitOpError() <<
"expects initializer region to yield a value "
3267 "of the reduction type";
3270 if (getReductionRegion().empty())
3271 return emitOpError() <<
"expects non-empty reduction region";
3272 Block &reductionEntryBlock = getReductionRegion().
front();
3277 return emitOpError() <<
"expects reduction region with two arguments of "
3278 "the reduction type";
3279 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3280 if (yieldOp.getResults().size() != 1 ||
3281 yieldOp.getResults().getTypes()[0] !=
getType())
3282 return emitOpError() <<
"expects reduction region to yield a value "
3283 "of the reduction type";
3286 if (!getAtomicReductionRegion().empty()) {
3287 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3291 return emitOpError() <<
"expects atomic reduction region with two "
3292 "arguments of the same type";
3293 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3296 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3297 return emitOpError() <<
"expects atomic reduction region arguments to "
3298 "be accumulators containing the reduction type";
3301 if (getCleanupRegion().empty())
3303 Block &cleanupEntryBlock = getCleanupRegion().
front();
3306 return emitOpError() <<
"expects cleanup region with one argument "
3307 "of the reduction type";
3317 const TaskOperands &clauses) {
3320 builder, state, clauses.iterated, clauses.affinityVars,
3321 clauses.allocateVars, clauses.allocatorVars,
3322 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3323 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3324 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3326 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3327 clauses.priority, clauses.privateVars,
3329 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3332LogicalResult TaskOp::verify() {
3333 LogicalResult verifyDependVars =
3335 getDependIteratedKinds(), getDependIterated());
3336 return failed(verifyDependVars)
3339 getInReductionVars(),
3340 getInReductionByref());
3348 const TaskgroupOperands &clauses) {
3350 TaskgroupOp::build(builder, state, clauses.allocateVars,
3351 clauses.allocatorVars, clauses.taskReductionVars,
3356LogicalResult TaskgroupOp::verify() {
3358 getTaskReductionVars(),
3359 getTaskReductionByref());
3367 const TaskloopOperands &clauses) {
3370 builder, state, clauses.allocateVars, clauses.allocatorVars,
3371 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3372 clauses.inReductionVars,
3374 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3375 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3376 clauses.privateVars,
3378 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3383LogicalResult TaskloopOp::verify() {
3384 if (getAllocateVars().size() != getAllocatorVars().size())
3386 "expected equal sizes for allocate and allocator variables");
3388 getReductionVars(), getReductionByref())) ||
3390 getInReductionVars(),
3391 getInReductionByref())))
3394 if (!getReductionVars().empty() && getNogroup())
3395 return emitError(
"if a reduction clause is present on the taskloop "
3396 "directive, the nogroup clause must not be specified");
3397 for (
auto var : getReductionVars()) {
3398 if (llvm::is_contained(getInReductionVars(), var))
3399 return emitError(
"the same list item cannot appear in both a reduction "
3400 "and an in_reduction clause");
3403 if (getGrainsize() && getNumTasks()) {
3405 "the grainsize clause and num_tasks clause are mutually exclusive and "
3406 "may not appear on the same taskloop directive");
3412LogicalResult TaskloopOp::verifyRegions() {
3413 if (LoopWrapperInterface nested = getNestedWrapper()) {
3416 <<
"'omp.composite' attribute missing from composite wrapper";
3420 if (!isa<SimdOp>(nested))
3421 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3422 }
else if (isComposite()) {
3424 <<
"'omp.composite' attribute present in non-composite wrapper";
3448 for (
auto &iv : ivs)
3449 iv.type = loopVarType;
3454 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3470 "collapse_num_loops",
3475 auto parseTiles = [&]() -> ParseResult {
3479 tiles.push_back(
tile);
3488 if (tiles.size() > 0)
3507 Region ®ion = getRegion();
3509 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3510 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3511 if (getLoopInclusive())
3513 p <<
"step (" << getLoopSteps() <<
") ";
3514 if (
int64_t numCollapse = getCollapseNumLoops())
3515 if (numCollapse > 1)
3516 p <<
"collapse(" << numCollapse <<
") ";
3519 p <<
"tiles(" << tiles.value() <<
") ";
3525 const LoopNestOperands &clauses) {
3527 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3528 clauses.loopLowerBounds, clauses.loopUpperBounds,
3529 clauses.loopSteps, clauses.loopInclusive,
3533LogicalResult LoopNestOp::verify() {
3534 if (getLoopLowerBounds().empty())
3535 return emitOpError() <<
"must represent at least one loop";
3537 if (getLoopLowerBounds().size() != getIVs().size())
3538 return emitOpError() <<
"number of range arguments and IVs do not match";
3540 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3541 if (lb.getType() != iv.getType())
3543 <<
"range argument type does not match corresponding IV type";
3546 uint64_t numIVs = getIVs().size();
3548 if (
const auto &numCollapse = getCollapseNumLoops())
3549 if (numCollapse > numIVs)
3551 <<
"collapse value is larger than the number of loops";
3554 if (tiles.value().size() > numIVs)
3555 return emitOpError() <<
"too few canonical loops for tile dimensions";
3557 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3558 return emitOpError() <<
"expects parent op to be a loop wrapper";
3563void LoopNestOp::gatherWrappers(
3566 while (
auto wrapper =
3567 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3568 wrappers.push_back(wrapper);
3577std::tuple<NewCliOp, OpOperand *, OpOperand *>
3583 return {{},
nullptr,
nullptr};
3586 "Unexpected type of cli");
3592 auto op = cast<LoopTransformationInterface>(use.getOwner());
3594 unsigned opnum = use.getOperandNumber();
3595 if (op.isGeneratee(opnum)) {
3596 assert(!gen &&
"Each CLI may have at most one def");
3598 }
else if (op.isApplyee(opnum)) {
3599 assert(!cons &&
"Each CLI may have at most one consumer");
3602 llvm_unreachable(
"Unexpected operand for a CLI");
3606 return {create, gen, cons};
3629 std::string cliName{
"cli"};
3633 .Case([&](CanonicalLoopOp op) {
3636 .Case([&](UnrollHeuristicOp op) -> std::string {
3637 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3639 .Case([&](FuseOp op) -> std::string {
3640 unsigned opnum =
generator->getOperandNumber();
3643 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3644 return "canonloop_fuse";
3648 .Case([&](TileOp op) -> std::string {
3649 auto [generateesFirst, generateesCount] =
3650 op.getGenerateesODSOperandIndexAndLength();
3651 unsigned firstGrid = generateesFirst;
3652 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3653 unsigned end = generateesFirst + generateesCount;
3654 unsigned opnum =
generator->getOperandNumber();
3656 if (firstGrid <= opnum && opnum < firstIntratile) {
3657 unsigned gridnum = opnum - firstGrid + 1;
3658 return (
"grid" + Twine(gridnum)).str();
3660 if (firstIntratile <= opnum && opnum < end) {
3661 unsigned intratilenum = opnum - firstIntratile + 1;
3662 return (
"intratile" + Twine(intratilenum)).str();
3664 llvm_unreachable(
"Unexpected generatee argument");
3666 .DefaultUnreachable(
"TODO: Custom name for this operation");
3669 setNameFn(
result, cliName);
3672LogicalResult NewCliOp::verify() {
3673 Value cli = getResult();
3676 "Unexpected type of cli");
3682 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3684 unsigned opnum = use.getOperandNumber();
3685 if (op.isGeneratee(opnum)) {
3688 emitOpError(
"CLI must have at most one generator");
3690 .
append(
"first generator here:");
3692 .
append(
"second generator here:");
3697 }
else if (op.isApplyee(opnum)) {
3700 emitOpError(
"CLI must have at most one consumer");
3702 .
append(
"first consumer here:")
3706 .
append(
"second consumer here:")
3713 llvm_unreachable(
"Unexpected operand for a CLI");
3721 .
append(
"see consumer here: ")
3744 setNameFn(&getRegion().front(),
"body_entry");
3747void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3755 p <<
'(' << getCli() <<
')';
3756 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3757 <<
" in range(" << getTripCount() <<
") ";
3767 CanonicalLoopInfoType cliType =
3768 CanonicalLoopInfoType::get(parser.
getContext());
3793 if (parser.
parseRegion(*region, {inductionVariable}))
3798 result.operands.append(cliOperand);
3804 return mlir::success();
3807LogicalResult CanonicalLoopOp::verify() {
3810 if (!getRegion().empty()) {
3811 Region ®ion = getRegion();
3814 "Canonical loop region must have exactly one argument");
3818 "Region argument must be the same type as the trip count");
3824Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3826std::pair<unsigned, unsigned>
3827CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3832std::pair<unsigned, unsigned>
3833CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3834 return getODSOperandIndexAndLength(odsIndex_cli);
3848 p <<
'(' << getApplyee() <<
')';
3855 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
3878 return mlir::success();
3881std::pair<unsigned, unsigned>
3882UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3883 return getODSOperandIndexAndLength(odsIndex_applyee);
3886std::pair<unsigned, unsigned>
3887UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3898 if (!generatees.empty())
3899 p <<
'(' << llvm::interleaved(generatees) <<
')';
3901 if (!applyees.empty())
3902 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
3944 bool isOnlyCanonLoops =
true;
3946 for (
Value applyee : op.getApplyees()) {
3947 auto [create, gen, cons] =
decodeCli(applyee);
3950 return op.emitOpError() <<
"applyee CLI has no generator";
3952 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3953 canonLoops.push_back(loop);
3955 isOnlyCanonLoops =
false;
3960 if (!isOnlyCanonLoops)
3964 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
3965 auto parentLoop = canonLoops[i - 1];
3966 auto loop = canonLoops[i];
3968 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
3969 return op.emitOpError()
3970 <<
"tiled loop nest must be nested within each other";
3972 parentIVs.insert(parentLoop.getInductionVar());
3977 bool isPerfectlyNested = [&]() {
3978 auto &parentBody = parentLoop.getRegion();
3979 if (!parentBody.hasOneBlock())
3981 auto &parentBlock = parentBody.getBlocks().front();
3983 auto nestedLoopIt = parentBlock.begin();
3984 if (nestedLoopIt == parentBlock.end() ||
3985 (&*nestedLoopIt != loop.getOperation()))
3988 auto termIt = std::next(nestedLoopIt);
3989 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3992 if (std::next(termIt) != parentBlock.end())
3997 if (!isPerfectlyNested)
3998 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4000 if (parentIVs.contains(loop.getTripCount()))
4001 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4018LogicalResult TileOp::verify() {
4019 if (getApplyees().empty())
4020 return emitOpError() <<
"must apply to at least one loop";
4022 if (getSizes().size() != getApplyees().size())
4023 return emitOpError() <<
"there must be one tile size for each applyee";
4025 if (!getGeneratees().empty() &&
4026 2 * getSizes().size() != getGeneratees().size())
4028 <<
"expecting two times the number of generatees than applyees";
4033std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4034 return getODSOperandIndexAndLength(odsIndex_applyees);
4037std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4038 return getODSOperandIndexAndLength(odsIndex_generatees);
4048 if (!generatees.empty())
4049 p <<
'(' << llvm::interleaved(generatees) <<
')';
4051 if (!applyees.empty())
4052 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4055LogicalResult FuseOp::verify() {
4056 if (getApplyees().size() < 2)
4057 return emitOpError() <<
"must apply to at least two loops";
4059 if (getFirst().has_value() && getCount().has_value()) {
4060 int64_t first = getFirst().value();
4061 int64_t count = getCount().value();
4062 if ((
unsigned)(first + count - 1) > getApplyees().size())
4063 return emitOpError() <<
"the numbers of applyees must be at least first "
4064 "minus one plus count attributes";
4065 if (!getGeneratees().empty() &&
4066 getGeneratees().size() != getApplyees().size() + 1 - count)
4067 return emitOpError() <<
"the number of generatees must be the number of "
4068 "aplyees plus one minus count";
4071 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4073 <<
"in a complete fuse the number of generatees must be exactly 1";
4075 for (
auto &&applyee : getApplyees()) {
4076 auto [create, gen, cons] =
decodeCli(applyee);
4079 return emitOpError() <<
"applyee CLI has no generator";
4080 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4083 <<
"currently only supports omp.canonical_loop as applyee";
4087std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4088 return getODSOperandIndexAndLength(odsIndex_applyees);
4091std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4092 return getODSOperandIndexAndLength(odsIndex_generatees);
4100 const CriticalDeclareOperands &clauses) {
4101 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4104LogicalResult CriticalDeclareOp::verify() {
4109 if (getNameAttr()) {
4110 SymbolRefAttr symbolRef = getNameAttr();
4114 return emitOpError() <<
"expected symbol reference " << symbolRef
4115 <<
" to point to a critical declaration";
4135 return op.
emitOpError() <<
"must be nested inside of a loop";
4139 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4140 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4142 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4143 "have an ordered clause";
4145 if (hasRegion && orderedAttr.getInt() != 0)
4146 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4147 "have a parameter present";
4149 if (!hasRegion && orderedAttr.getInt() == 0)
4150 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4151 "have a parameter present";
4152 }
else if (!isa<SimdOp>(wrapper)) {
4153 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4154 "or worksharing simd loop";
4160 const OrderedOperands &clauses) {
4161 OrderedOp::build(builder, state, clauses.doacrossDependType,
4162 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4165LogicalResult OrderedOp::verify() {
4169 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4170 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4171 return emitOpError() <<
"number of variables in depend clause does not "
4172 <<
"match number of iteration variables in the "
4179 const OrderedRegionOperands &clauses) {
4180 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4190 const TaskwaitOperands &clauses) {
4192 TaskwaitOp::build(builder, state,
nullptr,
4201LogicalResult AtomicReadOp::verify() {
4202 if (verifyCommon().
failed())
4203 return mlir::failure();
4205 if (
auto mo = getMemoryOrder()) {
4206 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4207 *mo == ClauseMemoryOrderKind::Release) {
4209 "memory-order must not be acq_rel or release for atomic reads");
4219LogicalResult AtomicWriteOp::verify() {
4220 if (verifyCommon().
failed())
4221 return mlir::failure();
4223 if (
auto mo = getMemoryOrder()) {
4224 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4225 *mo == ClauseMemoryOrderKind::Acquire) {
4227 "memory-order must not be acq_rel or acquire for atomic writes");
4237LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4243 if (
Value writeVal = op.getWriteOpVal()) {
4245 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4251LogicalResult AtomicUpdateOp::verify() {
4252 if (verifyCommon().
failed())
4253 return mlir::failure();
4255 if (
auto mo = getMemoryOrder()) {
4256 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4257 *mo == ClauseMemoryOrderKind::Acquire) {
4259 "memory-order must not be acq_rel or acquire for atomic updates");
4266LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4272AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4273 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4275 return dyn_cast<AtomicReadOp>(getSecondOp());
4278AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4279 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4281 return dyn_cast<AtomicWriteOp>(getSecondOp());
4284AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4285 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4287 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4290LogicalResult AtomicCaptureOp::verify() {
4294LogicalResult AtomicCaptureOp::verifyRegions() {
4295 if (verifyRegionsCommon().
failed())
4296 return mlir::failure();
4298 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4300 "operations inside capture region must not have hint clause");
4302 if (getFirstOp()->getAttr(
"memory_order") ||
4303 getSecondOp()->getAttr(
"memory_order"))
4305 "operations inside capture region must not have memory_order clause");
4314 const CancelOperands &clauses) {
4315 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4328LogicalResult CancelOp::verify() {
4329 ClauseCancellationConstructType cct = getCancelDirective();
4332 if (!structuralParent)
4333 return emitOpError() <<
"Orphaned cancel construct";
4335 if ((cct == ClauseCancellationConstructType::Parallel) &&
4336 !mlir::isa<ParallelOp>(structuralParent)) {
4337 return emitOpError() <<
"cancel parallel must appear "
4338 <<
"inside a parallel region";
4340 if (cct == ClauseCancellationConstructType::Loop) {
4343 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4347 <<
"cancel loop must appear inside a worksharing-loop region";
4349 if (wsloopOp.getNowaitAttr()) {
4350 return emitError() <<
"A worksharing construct that is canceled "
4351 <<
"must not have a nowait clause";
4353 if (wsloopOp.getOrderedAttr()) {
4354 return emitError() <<
"A worksharing construct that is canceled "
4355 <<
"must not have an ordered clause";
4358 }
else if (cct == ClauseCancellationConstructType::Sections) {
4362 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4364 return emitOpError() <<
"cancel sections must appear "
4365 <<
"inside a sections region";
4367 if (sectionsOp.getNowait()) {
4368 return emitError() <<
"A sections construct that is canceled "
4369 <<
"must not have a nowait clause";
4372 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4373 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4374 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4375 return emitOpError() <<
"cancel taskgroup must appear "
4376 <<
"inside a task region";
4386 const CancellationPointOperands &clauses) {
4387 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4390LogicalResult CancellationPointOp::verify() {
4391 ClauseCancellationConstructType cct = getCancelDirective();
4394 if (!structuralParent)
4395 return emitOpError() <<
"Orphaned cancellation point";
4397 if ((cct == ClauseCancellationConstructType::Parallel) &&
4398 !mlir::isa<ParallelOp>(structuralParent)) {
4399 return emitOpError() <<
"cancellation point parallel must appear "
4400 <<
"inside a parallel region";
4404 if ((cct == ClauseCancellationConstructType::Loop) &&
4405 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4406 return emitOpError() <<
"cancellation point loop must appear "
4407 <<
"inside a worksharing-loop region";
4409 if ((cct == ClauseCancellationConstructType::Sections) &&
4410 !mlir::isa<omp::SectionOp>(structuralParent)) {
4411 return emitOpError() <<
"cancellation point sections must appear "
4412 <<
"inside a sections region";
4414 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4415 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4416 !mlir::isa<omp::TaskloopOp>(structuralParent->
getParentOp()))) {
4417 return emitOpError() <<
"cancellation point taskgroup must appear "
4418 <<
"inside a task region";
4427LogicalResult MapBoundsOp::verify() {
4428 auto extent = getExtent();
4430 if (!extent && !upperbound)
4431 return emitError(
"expected extent or upperbound.");
4438 PrivateClauseOp::build(
4439 odsBuilder, odsState, symName, type,
4440 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4441 DataSharingClauseType::Private));
4444LogicalResult PrivateClauseOp::verifyRegions() {
4445 Type argType = getArgType();
4446 auto verifyTerminator = [&](
Operation *terminator,
4447 bool yieldsValue) -> LogicalResult {
4451 if (!llvm::isa<YieldOp>(terminator))
4453 <<
"expected exit block terminator to be an `omp.yield` op.";
4455 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4456 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4459 if (yieldedTypes.empty())
4463 <<
"Did not expect any values to be yielded.";
4466 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4470 <<
"Invalid yielded value. Expected type: " << argType
4473 if (yieldedTypes.empty())
4476 error << yieldedTypes;
4482 StringRef regionName,
4483 bool yieldsValue) -> LogicalResult {
4484 assert(!region.
empty());
4488 <<
"`" << regionName <<
"`: "
4489 <<
"expected " << expectedNumArgs
4492 for (
Block &block : region) {
4494 if (!block.mightHaveTerminator())
4497 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4505 for (
Region *region : getRegions())
4506 for (
Type ty : region->getArgumentTypes())
4508 return emitError() <<
"Region argument type mismatch: got " << ty
4509 <<
" expected " << argType <<
".";
4512 if (!initRegion.
empty() &&
4517 DataSharingClauseType dsType = getDataSharingType();
4519 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4520 return emitError(
"`private` clauses do not require a `copy` region.");
4522 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4524 "`firstprivate` clauses require at least a `copy` region.");
4526 if (dsType == DataSharingClauseType::FirstPrivate &&
4531 if (!getDeallocRegion().empty() &&
4544 const MaskedOperands &clauses) {
4545 MaskedOp::build(builder, state, clauses.filteredThreadId);
4553 const ScanOperands &clauses) {
4554 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4557LogicalResult ScanOp::verify() {
4558 if (hasExclusiveVars() == hasInclusiveVars())
4560 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4561 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4562 if (parentWsLoopOp.getReductionModAttr() &&
4563 parentWsLoopOp.getReductionModAttr().getValue() ==
4564 ReductionModifier::inscan)
4567 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4568 if (parentSimdOp.getReductionModAttr() &&
4569 parentSimdOp.getReductionModAttr().getValue() ==
4570 ReductionModifier::inscan)
4573 return emitError(
"SCAN directive needs to be enclosed within a parent "
4574 "worksharing loop construct or SIMD construct with INSCAN "
4575 "reduction modifier");
4580LogicalResult AllocateDirOp::verify() {
4581 std::optional<uint64_t> align = this->getAlign();
4583 if (align.has_value()) {
4584 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4585 return emitError() <<
"ALIGN value : " << align.value()
4586 <<
" must be power of 2";
4596mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4597 return getInTypeAttr().getValue();
4606 bool hasOperands =
false;
4607 std::int32_t typeparamsSize = 0;
4613 return mlir::failure();
4615 return mlir::failure();
4617 return mlir::failure();
4621 return mlir::failure();
4622 result.addAttribute(
"in_type", mlir::TypeAttr::get(intype));
4629 return mlir::failure();
4630 typeparamsSize = operands.size();
4633 std::int32_t shapeSize = 0;
4637 return mlir::failure();
4638 shapeSize = operands.size() - typeparamsSize;
4640 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4641 typeVec.push_back(idxTy);
4647 return mlir::failure();
4652 return mlir::failure();
4655 result.addAttribute(
"operandSegmentSizes",
4659 return mlir::failure();
4660 return mlir::success();
4675 if (!getTypeparams().empty()) {
4676 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4683 {
"in_type",
"operandSegmentSizes"});
4686llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4688 if (!mlir::dyn_cast<IntegerType>(outType))
4690 return mlir::success();
4697LogicalResult WorkdistributeOp::verify() {
4699 Region ®ion = getRegion();
4704 if (entryBlock.
empty())
4705 return emitOpError(
"region must contain a structured block");
4707 bool hasTerminator =
false;
4708 for (
Block &block : region) {
4709 if (isa<TerminatorOp>(block.back())) {
4710 if (hasTerminator) {
4711 return emitOpError(
"region must have exactly one terminator");
4713 hasTerminator =
true;
4716 if (!hasTerminator) {
4717 return emitOpError(
"region must be terminated with omp.terminator");
4721 if (isa<BarrierOp>(op)) {
4723 "explicit barriers are not allowed in workdistribute region");
4726 if (isa<ParallelOp>(op)) {
4728 "nested parallel constructs not allowed in workdistribute");
4730 if (isa<TeamsOp>(op)) {
4732 "nested teams constructs not allowed in workdistribute");
4736 if (walkResult.wasInterrupted())
4740 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4741 return emitOpError(
"workdistribute must be nested under teams");
4749LogicalResult DeclareSimdOp::verify() {
4752 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4754 return emitOpError() <<
"must be nested inside a function";
4756 if (getInbranch() && getNotinbranch())
4757 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
4767 const DeclareSimdOperands &clauses) {
4769 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4771 clauses.linearVars, clauses.linearStepVars,
4772 clauses.linearVarTypes, clauses.linearModifiers,
4773 clauses.notinbranch, clauses.simdlen,
4774 clauses.uniformVars);
4791 return mlir::failure();
4792 return mlir::success();
4799 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
4802 p << uniformVars[i] <<
" : " << uniformTypes[i];
4817 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4818 [&]() -> ParseResult {
return success(); })))
4852 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4853 if (parser.parseArgument(arg))
4857 if (succeeded(parser.parseOptionalColon())) {
4858 if (parser.parseType(arg.type))
4861 arg.type = parser.getBuilder().getIndexType();
4873 OpAsmParser::UnresolvedOperand lb, ub, st;
4874 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
4875 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
4876 parser.parseOperand(st))
4881 steps.push_back(st);
4889 if (ivArgs.size() != lbs.size())
4891 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
4894 for (
auto &arg : ivArgs) {
4895 lbTypes.push_back(arg.type);
4896 ubTypes.push_back(arg.type);
4897 stepTypes.push_back(arg.type);
4917 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
4920 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
4928LogicalResult IteratorOp::verify() {
4929 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
4931 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
4933 for (
auto [lb,
ub, step] : llvm::zip_equal(
4934 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
4936 return emitOpError() <<
"loop step must not be zero";
4940 IntegerAttr stepAttr;
4946 const APInt &lbVal = lbAttr.getValue();
4947 const APInt &ubVal = ubAttr.getValue();
4948 const APInt &stepVal = stepAttr.getValue();
4949 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
4950 return emitOpError() <<
"positive loop step requires lower bound to be "
4951 "less than or equal to upper bound";
4952 if (stepVal.isNegative() && lbVal.slt(ubVal))
4953 return emitOpError() <<
"negative loop step requires lower bound to be "
4954 "greater than or equal to upper bound";
4957 Block &
b = getRegion().front();
4958 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
4961 return emitOpError() <<
"region must be terminated by omp.yield";
4963 if (yield.getNumOperands() != 1)
4965 <<
"omp.yield in omp.iterator region must yield exactly one value";
4967 mlir::Type yieldedTy = yield.getOperand(0).getType();
4968 mlir::Type elemTy = iteratedTy.getElementType();
4970 if (yieldedTy != elemTy)
4971 return emitOpError() <<
"omp.iterated element type (" << elemTy
4972 <<
") does not match omp.yield operand type ("
4973 << yieldedTy <<
")";
4978#define GET_ATTRDEF_CLASSES
4979#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4981#define GET_OP_CLASSES
4982#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4984#define GET_TYPEDEF_CLASSES
4985#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > ®ionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region ®ion, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > ©privateVars, SmallVectorImpl< Type > ©privateTypes, ArrayAttr ©privateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.