27#include "llvm/ADT/ArrayRef.h"
28#include "llvm/ADT/PostOrderIterator.h"
29#include "llvm/ADT/STLExtras.h"
30#include "llvm/ADT/STLForwardCompat.h"
31#include "llvm/ADT/SmallString.h"
32#include "llvm/ADT/StringExtras.h"
33#include "llvm/ADT/StringRef.h"
34#include "llvm/ADT/TypeSwitch.h"
35#include "llvm/ADT/bit.h"
36#include "llvm/Support/InterleavedRange.h"
42#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
52 return attrs.empty() ?
nullptr : ArrayAttr::get(context, attrs);
66struct MemRefPointerLikeModel
67 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
70 return llvm::cast<MemRefType>(pointer).getElementType();
74struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
101 bool isRegionArgOfOp;
111 assert(isRegionArgOfOp &&
"Must describe a region operand");
114 size_t &getArgIdx() {
115 assert(isRegionArgOfOp &&
"Must describe a region operand");
120 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
124 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
127 bool isLoopOp()
const {
128 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
129 return isa<CanonicalLoopOp>(op);
131 Region *&getParentRegion() {
132 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
135 size_t &getLoopDepth() {
136 assert(!isRegionArgOfOp &&
"Must describe a operation of a region");
140 void skipIf(
bool v =
true) { skip = skip || v; }
158 llvm::ReversePostOrderTraversal<Block *> traversal(&r->
getBlocks().front());
161 size_t sequentialIdx = -1;
162 bool isOnlyContainerOp =
true;
163 for (
Block *
b : traversal) {
165 if (&op == o && !found) {
169 if (op.getNumRegions()) {
172 isOnlyContainerOp =
false;
174 if (found && !isOnlyContainerOp)
179 Component &containerOpInRegion = components.emplace_back();
180 containerOpInRegion.isRegionArgOfOp =
false;
181 containerOpInRegion.isUnique = isOnlyContainerOp;
182 containerOpInRegion.getContainerOp() = o;
183 containerOpInRegion.getOpPos() = sequentialIdx;
184 containerOpInRegion.getParentRegion() = r;
189 Component ®ionArgOfOperation = components.emplace_back();
190 regionArgOfOperation.isRegionArgOfOp =
true;
191 regionArgOfOperation.isUnique =
true;
192 regionArgOfOperation.getArgIdx() = 0;
193 regionArgOfOperation.getOwnerOp() = parent;
205 for (
auto [idx, region] : llvm::enumerate(o->
getRegions())) {
209 llvm_unreachable(
"Region not child of its parent operation");
211 regionArgOfOperation.isUnique =
false;
212 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
220 for (Component &c : components)
221 c.skipIf(c.isRegionArgOfOp && c.isUnique);
224 size_t numSurroundingLoops = 0;
225 for (Component &c : llvm::reverse(components)) {
230 if (c.isRegionArgOfOp) {
231 numSurroundingLoops = 0;
238 numSurroundingLoops = 0;
240 c.getLoopDepth() = numSurroundingLoops;
243 if (isa<CanonicalLoopOp>(c.getContainerOp()))
244 numSurroundingLoops += 1;
249 bool isLoopNest =
false;
250 for (Component &c : components) {
251 if (c.skip || c.isRegionArgOfOp)
254 if (!isLoopNest && c.getLoopDepth() >= 1) {
257 }
else if (isLoopNest) {
259 c.skipIf(c.isUnique);
263 if (c.getLoopDepth() == 0)
270 for (Component &c : components)
271 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
272 !isa<CanonicalLoopOp>(c.getContainerOp()));
276 bool newRegion =
true;
277 for (Component &c : llvm::reverse(components)) {
278 c.skipIf(newRegion && c.isUnique);
285 if (!c.isRegionArgOfOp && c.getContainerOp())
291 llvm::raw_svector_ostream NameOS(Name);
292 for (
auto &c : llvm::reverse(components)) {
296 if (c.isRegionArgOfOp)
297 NameOS <<
"_r" << c.getArgIdx();
298 else if (c.getLoopDepth() >= 1)
299 NameOS <<
"_d" << c.getLoopDepth();
301 NameOS <<
"_s" << c.getOpPos();
304 return NameOS.str().str();
307void OpenMPDialect::initialize() {
310#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
313#define GET_ATTRDEF_LIST
314#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
317#define GET_TYPEDEF_LIST
318#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
321 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
324 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
329 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
335 mlir::LLVM::GlobalOp::attachInterface<
338 mlir::LLVM::LLVMFuncOp::attachInterface<
341 mlir::func::FuncOp::attachInterface<
367 allocatorVars.push_back(operand);
368 allocatorTypes.push_back(type);
374 allocateVars.push_back(operand);
375 allocateTypes.push_back(type);
386 for (
unsigned i = 0; i < allocateVars.size(); ++i) {
387 std::string separator = i == allocateVars.size() - 1 ?
"" :
", ";
388 p << allocatorVars[i] <<
" : " << allocatorTypes[i] <<
" -> ";
389 p << allocateVars[i] <<
" : " << allocateTypes[i] << separator;
397template <
typename ClauseAttr>
399 using ClauseT =
decltype(std::declval<ClauseAttr>().getValue());
404 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
405 attr = ClauseAttr::get(parser.
getContext(), *enumValue);
408 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
411template <
typename ClauseAttr>
413 p << stringifyEnum(attr.getValue());
438 std::optional<omp::LinearModifier> linearModifier;
440 linearModifier = omp::LinearModifier::val;
442 linearModifier = omp::LinearModifier::ref;
444 linearModifier = omp::LinearModifier::uval;
447 bool hasLinearModifierParens = linearModifier.has_value();
448 if (hasLinearModifierParens && parser.
parseLParen())
456 if (hasLinearModifierParens && parser.
parseRParen())
459 linearVars.push_back(var);
460 linearTypes.push_back(type);
461 linearStepVars.push_back(stepVar);
462 linearStepTypes.push_back(stepType);
463 if (linearModifier) {
465 omp::LinearModifierAttr::get(parser.
getContext(), *linearModifier));
467 modifiers.push_back(UnitAttr::get(parser.
getContext()));
473 linearModifiers = ArrayAttr::get(parser.
getContext(), modifiers);
482 size_t linearVarsSize = linearVars.size();
483 for (
unsigned i = 0; i < linearVarsSize; ++i) {
487 Attribute modAttr = linearModifiers ? linearModifiers[i] :
nullptr;
488 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) :
nullptr;
490 p << omp::stringifyLinearModifier(mod.getValue()) <<
"(";
492 p << linearVars[i] <<
" : " << linearTypes[i];
493 p <<
" = " << linearStepVars[i] <<
" : " << stepVarTypes[i];
509 if (!linearModifiers)
511 if (linearModifiers->size() != linearVars.size())
513 <<
"expected as many linear modifiers as linear variables";
514 if (!isDeclareSimd) {
515 for (
Attribute attr : *linearModifiers) {
518 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
521 omp::LinearModifier mod = modAttr.getValue();
522 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 <<
"linear modifier '" << omp::stringifyLinearModifier(mod)
525 <<
"' may only be specified on a declare simd directive";
540 for (
const auto &it : nontemporalVars)
541 if (!nontemporalItems.insert(it).second)
542 return op->
emitOpError() <<
"nontemporal variable used more than once";
551 std::optional<ArrayAttr> alignments,
554 if (!alignedVars.empty()) {
555 if (!alignments || alignments->size() != alignedVars.size())
557 <<
"expected as many alignment values as aligned variables";
560 return op->
emitOpError() <<
"unexpected alignment values attribute";
566 for (
auto it : alignedVars)
567 if (!alignedItems.insert(it).second)
568 return op->
emitOpError() <<
"aligned variable used more than once";
574 for (
unsigned i = 0; i < (*alignments).size(); ++i) {
575 if (
auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
576 if (intAttr.getValue().sle(0))
577 return op->
emitOpError() <<
"alignment should be greater than 0";
579 return op->
emitOpError() <<
"expected integer alignment";
596 if (parser.parseOperand(alignedVars.emplace_back()) ||
597 parser.parseColonType(alignedTypes.emplace_back()) ||
598 parser.parseArrow() ||
599 parser.parseAttribute(alignmentVec.emplace_back())) {
606 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
613 std::optional<ArrayAttr> alignments) {
614 for (
unsigned i = 0; i < alignedVars.size(); ++i) {
617 p << alignedVars[i] <<
" : " << alignedVars[i].
getType();
618 p <<
" -> " << (*alignments)[i];
629 if (modifiers.size() > 2)
631 for (
const auto &mod : modifiers) {
634 auto symbol = symbolizeScheduleModifier(mod);
637 <<
" unknown modifier type: " << mod;
642 if (modifiers.size() == 1) {
643 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
644 modifiers.push_back(modifiers[0]);
645 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
else if (modifiers.size() == 2) {
650 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
651 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 <<
" incorrect modifier order";
669 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
670 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
675 std::optional<mlir::omp::ClauseScheduleKind> schedule =
676 symbolizeClauseScheduleKind(keyword);
680 scheduleAttr = ClauseScheduleKindAttr::get(parser.
getContext(), *schedule);
682 case ClauseScheduleKind::Static:
683 case ClauseScheduleKind::Dynamic:
684 case ClauseScheduleKind::Guided:
690 chunkSize = std::nullopt;
693 case ClauseScheduleKind::Auto:
694 case ClauseScheduleKind::Runtime:
695 case ClauseScheduleKind::Distribute:
696 chunkSize = std::nullopt;
705 modifiers.push_back(mod);
711 if (!modifiers.empty()) {
713 if (std::optional<ScheduleModifier> mod =
714 symbolizeScheduleModifier(modifiers[0])) {
715 scheduleMod = ScheduleModifierAttr::get(parser.
getContext(), *mod);
717 return parser.
emitError(loc,
"invalid schedule modifier");
720 if (modifiers.size() > 1) {
721 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
731 ClauseScheduleKindAttr scheduleKind,
732 ScheduleModifierAttr scheduleMod,
733 UnitAttr scheduleSimd,
Value scheduleChunk,
734 Type scheduleChunkType) {
735 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 p <<
" = " << scheduleChunk <<
" : " << scheduleChunk.
getType();
739 p <<
", " << stringifyScheduleModifier(scheduleMod.getValue());
751 ClauseOrderKindAttr &order,
752 OrderModifierAttr &orderMod) {
757 if (std::optional<OrderModifier> enumValue =
758 symbolizeOrderModifier(enumStr)) {
759 orderMod = OrderModifierAttr::get(parser.
getContext(), *enumValue);
766 if (std::optional<ClauseOrderKind> enumValue =
767 symbolizeClauseOrderKind(enumStr)) {
768 order = ClauseOrderKindAttr::get(parser.
getContext(), *enumValue);
771 return parser.
emitError(loc,
"invalid clause value: '") << enumStr <<
"'";
775 ClauseOrderKindAttr order,
776 OrderModifierAttr orderMod) {
778 p << stringifyOrderModifier(orderMod.getValue()) <<
":";
780 p << stringifyClauseOrderKind(order.getValue());
783template <
typename ClauseTypeAttr,
typename ClauseType>
786 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 std::optional<ClauseType> (*symbolizeClause)(StringRef),
789 StringRef clauseName) {
792 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
793 prescriptiveness = ClauseTypeAttr::get(parser.
getContext(), *enumValue);
798 <<
"invalid " << clauseName <<
" modifier : '" << enumStr <<
"'";
808 <<
"expected " << clauseName <<
" operand";
811 if (operand.has_value()) {
819template <
typename ClauseTypeAttr,
typename ClauseType>
822 ClauseTypeAttr prescriptiveness,
Value operand,
824 StringRef (*stringifyClauseType)(ClauseType)) {
826 if (prescriptiveness)
827 p << stringifyClauseType(prescriptiveness.getValue()) <<
", ";
830 p << operand <<
": " << operandType;
840 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
841 Type &grainsizeType) {
843 parser, grainsizeMod, grainsize, grainsizeType,
844 &symbolizeClauseGrainsizeType,
"grainsize");
848 ClauseGrainsizeTypeAttr grainsizeMod,
851 p, op, grainsizeMod, grainsize, grainsizeType,
852 &stringifyClauseGrainsizeType);
862 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
863 Type &numTasksType) {
865 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
870 ClauseNumTasksTypeAttr numTasksMod,
873 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
882 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
883 SmallVectorImpl<Type> &types;
884 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
885 SmallVectorImpl<Type> &types)
886 : vars(vars), types(types) {}
888struct PrivateParseArgs {
889 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
890 llvm::SmallVectorImpl<Type> &types;
892 UnitAttr &needsBarrier;
894 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
895 SmallVectorImpl<Type> &types,
ArrayAttr &syms,
896 UnitAttr &needsBarrier,
898 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
899 mapIndices(mapIndices) {}
902struct ReductionParseArgs {
903 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
904 SmallVectorImpl<Type> &types;
907 ReductionModifierAttr *modifier;
908 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
910 ArrayAttr &syms, ReductionModifierAttr *mod =
nullptr)
911 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
914struct AllRegionParseArgs {
915 std::optional<MapParseArgs> hasDeviceAddrArgs;
916 std::optional<MapParseArgs> hostEvalArgs;
917 std::optional<ReductionParseArgs> inReductionArgs;
918 std::optional<MapParseArgs> mapArgs;
919 std::optional<PrivateParseArgs> privateArgs;
920 std::optional<ReductionParseArgs> reductionArgs;
921 std::optional<ReductionParseArgs> taskReductionArgs;
922 std::optional<MapParseArgs> useDeviceAddrArgs;
923 std::optional<MapParseArgs> useDevicePtrArgs;
928 return "private_barrier";
938 ReductionModifierAttr *modifier =
nullptr,
939 UnitAttr *needsBarrier =
nullptr) {
943 unsigned regionArgOffset = regionPrivateArgs.size();
953 std::optional<ReductionModifier> enumValue =
954 symbolizeReductionModifier(enumStr);
955 if (!enumValue.has_value())
957 *modifier = ReductionModifierAttr::get(parser.
getContext(), *enumValue);
964 isByRefVec.push_back(
965 parser.parseOptionalKeyword(
"byref").succeeded());
967 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
970 if (parser.parseOperand(operands.emplace_back()) ||
971 parser.parseArrow() ||
972 parser.parseArgument(regionPrivateArgs.emplace_back()))
976 if (parser.parseOptionalLSquare().succeeded()) {
977 if (parser.parseKeyword(
"map_idx") || parser.parseEqual() ||
978 parser.parseInteger(mapIndicesVec.emplace_back()) ||
979 parser.parseRSquare())
982 mapIndicesVec.push_back(-1);
994 if (parser.parseType(types.emplace_back()))
1001 if (operands.size() != types.size())
1010 *needsBarrier = mlir::UnitAttr::get(parser.
getContext());
1013 auto *argsBegin = regionPrivateArgs.begin();
1015 argsBegin + regionArgOffset + types.size());
1016 for (
auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1022 *symbols = ArrayAttr::get(parser.
getContext(), symbolAttrs);
1025 if (!mapIndicesVec.empty())
1038 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1053 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1059 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1060 &privateArgs->syms, privateArgs->mapIndices,
nullptr,
1061 nullptr, &privateArgs->needsBarrier)))
1070 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1075 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1076 &reductionArgs->syms,
nullptr, &reductionArgs->byref,
1077 reductionArgs->modifier)))
1084 AllRegionParseArgs args) {
1088 args.hasDeviceAddrArgs)))
1090 <<
"invalid `has_device_addr` format";
1093 args.hostEvalArgs)))
1095 <<
"invalid `host_eval` format";
1098 args.inReductionArgs)))
1100 <<
"invalid `in_reduction` format";
1105 <<
"invalid `map_entries` format";
1110 <<
"invalid `private` format";
1113 args.reductionArgs)))
1115 <<
"invalid `reduction` format";
1118 args.taskReductionArgs)))
1120 <<
"invalid `task_reduction` format";
1123 args.useDeviceAddrArgs)))
1125 <<
"invalid `use_device_addr` format";
1128 args.useDevicePtrArgs)))
1130 <<
"invalid `use_device_addr` format";
1132 return parser.
parseRegion(region, entryBlockArgs);
1151 AllRegionParseArgs args;
1152 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1153 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1154 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1155 inReductionByref, inReductionSyms);
1156 args.mapArgs.emplace(mapVars, mapTypes);
1157 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1158 privateNeedsBarrier, &privateMaps);
1169 UnitAttr &privateNeedsBarrier) {
1170 AllRegionParseArgs args;
1171 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1172 inReductionByref, inReductionSyms);
1173 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1174 privateNeedsBarrier);
1185 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1189 AllRegionParseArgs args;
1190 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1191 inReductionByref, inReductionSyms);
1192 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1193 privateNeedsBarrier);
1194 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1195 reductionSyms, &reductionMod);
1203 UnitAttr &privateNeedsBarrier) {
1204 AllRegionParseArgs args;
1205 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1206 privateNeedsBarrier);
1214 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1218 AllRegionParseArgs args;
1219 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1220 privateNeedsBarrier);
1221 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1222 reductionSyms, &reductionMod);
1231 AllRegionParseArgs args;
1232 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1233 taskReductionByref, taskReductionSyms);
1243 AllRegionParseArgs args;
1244 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1245 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1254struct MapPrintArgs {
1259struct PrivatePrintArgs {
1263 UnitAttr needsBarrier;
1267 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1268 mapIndices(mapIndices) {}
1270struct ReductionPrintArgs {
1275 ReductionModifierAttr modifier;
1277 ArrayAttr syms, ReductionModifierAttr mod =
nullptr)
1278 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1280struct AllRegionPrintArgs {
1281 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1282 std::optional<MapPrintArgs> hostEvalArgs;
1283 std::optional<ReductionPrintArgs> inReductionArgs;
1284 std::optional<MapPrintArgs> mapArgs;
1285 std::optional<PrivatePrintArgs> privateArgs;
1286 std::optional<ReductionPrintArgs> reductionArgs;
1287 std::optional<ReductionPrintArgs> taskReductionArgs;
1288 std::optional<MapPrintArgs> useDeviceAddrArgs;
1289 std::optional<MapPrintArgs> useDevicePtrArgs;
1298 ReductionModifierAttr modifier =
nullptr, UnitAttr needsBarrier =
nullptr) {
1299 if (argsSubrange.empty())
1302 p << clauseName <<
"(";
1305 p <<
"mod: " << stringifyReductionModifier(modifier.getValue()) <<
", ";
1309 symbols = ArrayAttr::get(ctx, values);
1322 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1323 mapIndices.asArrayRef(),
1324 byref.asArrayRef()),
1326 auto [op, arg, sym, map, isByRef] = t;
1332 p << op <<
" -> " << arg;
1335 p <<
" [map_idx=" << map <<
"]";
1338 llvm::interleaveComma(types, p);
1346 StringRef clauseName,
ValueRange argsSubrange,
1347 std::optional<MapPrintArgs> mapArgs) {
1354 StringRef clauseName,
ValueRange argsSubrange,
1355 std::optional<PrivatePrintArgs> privateArgs) {
1358 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1359 privateArgs->syms, privateArgs->mapIndices,
nullptr,
1360 nullptr, privateArgs->needsBarrier);
1366 std::optional<ReductionPrintArgs> reductionArgs) {
1369 reductionArgs->vars, reductionArgs->types,
1370 reductionArgs->syms,
nullptr,
1371 reductionArgs->byref, reductionArgs->modifier);
1375 const AllRegionPrintArgs &args) {
1376 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1380 iface.getHasDeviceAddrBlockArgs(),
1381 args.hasDeviceAddrArgs);
1385 args.inReductionArgs);
1391 args.reductionArgs);
1393 iface.getTaskReductionBlockArgs(),
1394 args.taskReductionArgs);
1396 iface.getUseDeviceAddrBlockArgs(),
1397 args.useDeviceAddrArgs);
1399 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1415 AllRegionPrintArgs args;
1416 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1417 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1418 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1419 inReductionByref, inReductionSyms);
1420 args.mapArgs.emplace(mapVars, mapTypes);
1421 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1422 privateNeedsBarrier, privateMaps);
1430 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1431 AllRegionPrintArgs args;
1432 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1433 inReductionByref, inReductionSyms);
1434 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1435 privateNeedsBarrier,
1444 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1445 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1448 AllRegionPrintArgs args;
1449 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1450 inReductionByref, inReductionSyms);
1451 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1452 privateNeedsBarrier,
1454 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1455 reductionSyms, reductionMod);
1462 UnitAttr privateNeedsBarrier) {
1463 AllRegionPrintArgs args;
1464 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1465 privateNeedsBarrier,
1473 ReductionModifierAttr reductionMod,
ValueRange reductionVars,
1476 AllRegionPrintArgs args;
1477 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1478 privateNeedsBarrier,
1480 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1481 reductionSyms, reductionMod);
1491 AllRegionPrintArgs args;
1492 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1493 taskReductionByref, taskReductionSyms);
1503 AllRegionPrintArgs args;
1504 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1505 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1509template <
typename ParsePrefixFn>
1518 if (failed(parsePrefix()))
1526 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1527 iteratedVars.push_back(v);
1528 iteratedTypes.push_back(ty);
1530 plainVars.push_back(v);
1531 plainTypes.push_back(ty);
1537template <
typename Pr
intPrefixFn>
1541 PrintPrefixFn &&printPrefixForPlain,
1542 PrintPrefixFn &&printPrefixForIterated) {
1549 p << v <<
" : " << t;
1553 for (
unsigned i = 0; i < iteratedVars.size(); ++i)
1554 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1555 for (
unsigned i = 0; i < plainVars.size(); ++i)
1556 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1564 if (!reductionVars.empty()) {
1565 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1567 <<
"expected as many reduction symbol references "
1568 "as reduction variables";
1569 if (reductionByref && reductionByref->size() != reductionVars.size())
1570 return op->
emitError() <<
"expected as many reduction variable by "
1571 "reference attributes as reduction variables";
1574 return op->
emitOpError() <<
"unexpected reduction symbol references";
1581 for (
auto args : llvm::zip(reductionVars, *reductionSyms)) {
1582 Value accum = std::get<0>(args);
1584 if (!accumulators.insert(accum).second)
1585 return op->
emitOpError() <<
"accumulator variable used more than once";
1588 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1592 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1593 <<
" to point to a reduction declaration";
1595 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1597 <<
"expected accumulator (" << varType
1598 <<
") to be the same type as reduction declaration ("
1599 << decl.getAccumulatorType() <<
")";
1618 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1619 parser.parseArrow() ||
1620 parser.parseAttribute(symsVec.emplace_back()) ||
1621 parser.parseColonType(copyprivateTypes.emplace_back()))
1627 copyprivateSyms = ArrayAttr::get(parser.
getContext(), syms);
1635 std::optional<ArrayAttr> copyprivateSyms) {
1636 if (!copyprivateSyms.has_value())
1638 llvm::interleaveComma(
1639 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1640 [&](
const auto &args) {
1641 p << std::get<0>(args) <<
" -> " << std::get<1>(args) <<
" : "
1642 << std::get<2>(args);
1649 std::optional<ArrayAttr> copyprivateSyms) {
1650 size_t copyprivateSymsSize =
1651 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1652 if (copyprivateSymsSize != copyprivateVars.size())
1653 return op->
emitOpError() <<
"inconsistent number of copyprivate vars (= "
1654 << copyprivateVars.size()
1655 <<
") and functions (= " << copyprivateSymsSize
1656 <<
"), both must be equal";
1657 if (!copyprivateSyms.has_value())
1660 for (
auto copyprivateVarAndSym :
1661 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1663 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1664 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1666 if (mlir::func::FuncOp mlirFuncOp =
1669 funcOp = mlirFuncOp;
1670 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1673 funcOp = llvmFuncOp;
1675 auto getNumArguments = [&] {
1676 return std::visit([](
auto &f) {
return f.getNumArguments(); }, *funcOp);
1679 auto getArgumentType = [&](
unsigned i) {
1680 return std::visit([i](
auto &f) {
return f.getArgumentTypes()[i]; },
1685 return op->
emitOpError() <<
"expected symbol reference " << symbolRef
1686 <<
" to point to a copy function";
1688 if (getNumArguments() != 2)
1690 <<
"expected copy function " << symbolRef <<
" to have 2 operands";
1692 Type argTy = getArgumentType(0);
1693 if (argTy != getArgumentType(1))
1694 return op->
emitOpError() <<
"expected copy function " << symbolRef
1695 <<
" arguments to have the same type";
1697 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1698 if (argTy != varType)
1700 <<
"expected copy function arguments' type (" << argTy
1701 <<
") to be the same as copyprivate variable's type (" << varType
1726 OpAsmParser::UnresolvedOperand operand;
1728 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1729 parser.parseOperand(operand) || parser.parseColonType(ty))
1731 std::optional<ClauseTaskDepend> keywordDepend =
1732 symbolizeClauseTaskDepend(keyword);
1736 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1737 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1738 iteratedVars.push_back(operand);
1739 iteratedTypes.push_back(ty);
1740 iterKindsVec.push_back(kindAttr);
1742 dependVars.push_back(operand);
1743 dependTypes.push_back(ty);
1744 kindsVec.push_back(kindAttr);
1750 dependKinds = ArrayAttr::get(parser.
getContext(), kinds);
1752 iteratedKinds = ArrayAttr::get(parser.
getContext(), iterKinds);
1759 std::optional<ArrayAttr> dependKinds,
1762 std::optional<ArrayAttr> iteratedKinds) {
1765 std::optional<ArrayAttr> kinds) {
1766 for (
unsigned i = 0, e = vars.size(); i < e; ++i) {
1769 p << stringifyClauseTaskDepend(
1770 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1772 <<
" -> " << vars[i] <<
" : " << types[i];
1776 printEntries(dependVars, dependTypes, dependKinds);
1777 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1782 std::optional<ArrayAttr> dependKinds,
1784 std::optional<ArrayAttr> iteratedKinds,
1786 if (!dependVars.empty()) {
1787 if (!dependKinds || dependKinds->size() != dependVars.size())
1788 return op->
emitOpError() <<
"expected as many depend values"
1789 " as depend variables";
1791 if (dependKinds && !dependKinds->empty())
1792 return op->
emitOpError() <<
"unexpected depend values";
1795 if (!iteratedVars.empty()) {
1796 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1797 return op->
emitOpError() <<
"expected as many depend iterated values"
1798 " as depend iterated variables";
1800 if (iteratedKinds && !iteratedKinds->empty())
1801 return op->
emitOpError() <<
"unexpected depend iterated values";
1816 IntegerAttr &hintAttr) {
1817 StringRef hintKeyword;
1823 auto parseKeyword = [&]() -> ParseResult {
1826 if (hintKeyword ==
"uncontended")
1828 else if (hintKeyword ==
"contended")
1830 else if (hintKeyword ==
"nonspeculative")
1832 else if (hintKeyword ==
"speculative")
1836 << hintKeyword <<
" is not a valid hint";
1847 IntegerAttr hintAttr) {
1848 int64_t hint = hintAttr.getInt();
1856 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1858 bool uncontended = bitn(hint, 0);
1859 bool contended = bitn(hint, 1);
1860 bool nonspeculative = bitn(hint, 2);
1861 bool speculative = bitn(hint, 3);
1865 hints.push_back(
"uncontended");
1867 hints.push_back(
"contended");
1869 hints.push_back(
"nonspeculative");
1871 hints.push_back(
"speculative");
1873 llvm::interleaveComma(hints, p);
1880 auto bitn = [](
int value,
int n) ->
bool {
return value & (1 << n); };
1882 bool uncontended = bitn(hint, 0);
1883 bool contended = bitn(hint, 1);
1884 bool nonspeculative = bitn(hint, 2);
1885 bool speculative = bitn(hint, 3);
1887 if (uncontended && contended)
1888 return op->
emitOpError() <<
"the hints omp_sync_hint_uncontended and "
1889 "omp_sync_hint_contended cannot be combined";
1890 if (nonspeculative && speculative)
1891 return op->
emitOpError() <<
"the hints omp_sync_hint_nonspeculative and "
1892 "omp_sync_hint_speculative cannot be combined.";
1903 return (value & flag) == flag;
1911static ParseResult parseMapClause(
OpAsmParser &parser,
1912 ClauseMapFlagsAttr &mapType) {
1913 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1916 auto parseTypeAndMod = [&]() -> ParseResult {
1917 StringRef mapTypeMod;
1921 if (mapTypeMod ==
"always")
1922 mapTypeBits |= ClauseMapFlags::always;
1924 if (mapTypeMod ==
"implicit")
1925 mapTypeBits |= ClauseMapFlags::implicit;
1927 if (mapTypeMod ==
"ompx_hold")
1928 mapTypeBits |= ClauseMapFlags::ompx_hold;
1930 if (mapTypeMod ==
"close")
1931 mapTypeBits |= ClauseMapFlags::close;
1933 if (mapTypeMod ==
"present")
1934 mapTypeBits |= ClauseMapFlags::present;
1936 if (mapTypeMod ==
"to")
1937 mapTypeBits |= ClauseMapFlags::to;
1939 if (mapTypeMod ==
"from")
1940 mapTypeBits |= ClauseMapFlags::from;
1942 if (mapTypeMod ==
"tofrom")
1943 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1945 if (mapTypeMod ==
"delete")
1946 mapTypeBits |= ClauseMapFlags::del;
1948 if (mapTypeMod ==
"storage")
1949 mapTypeBits |= ClauseMapFlags::storage;
1951 if (mapTypeMod ==
"return_param")
1952 mapTypeBits |= ClauseMapFlags::return_param;
1954 if (mapTypeMod ==
"private")
1955 mapTypeBits |= ClauseMapFlags::priv;
1957 if (mapTypeMod ==
"literal")
1958 mapTypeBits |= ClauseMapFlags::literal;
1960 if (mapTypeMod ==
"attach")
1961 mapTypeBits |= ClauseMapFlags::attach;
1963 if (mapTypeMod ==
"attach_always")
1964 mapTypeBits |= ClauseMapFlags::attach_always;
1966 if (mapTypeMod ==
"attach_never")
1967 mapTypeBits |= ClauseMapFlags::attach_never;
1969 if (mapTypeMod ==
"attach_auto")
1970 mapTypeBits |= ClauseMapFlags::attach_auto;
1972 if (mapTypeMod ==
"ref_ptr")
1973 mapTypeBits |= ClauseMapFlags::ref_ptr;
1975 if (mapTypeMod ==
"ref_ptee")
1976 mapTypeBits |= ClauseMapFlags::ref_ptee;
1978 if (mapTypeMod ==
"ref_ptr_ptee")
1979 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1981 if (mapTypeMod ==
"is_device_ptr")
1982 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1999 ClauseMapFlagsAttr mapType) {
2001 ClauseMapFlags mapFlags = mapType.getValue();
2006 mapTypeStrs.push_back(
"always");
2008 mapTypeStrs.push_back(
"implicit");
2010 mapTypeStrs.push_back(
"ompx_hold");
2012 mapTypeStrs.push_back(
"close");
2014 mapTypeStrs.push_back(
"present");
2023 mapTypeStrs.push_back(
"tofrom");
2025 mapTypeStrs.push_back(
"from");
2027 mapTypeStrs.push_back(
"to");
2030 mapTypeStrs.push_back(
"delete");
2032 mapTypeStrs.push_back(
"return_param");
2034 mapTypeStrs.push_back(
"storage");
2036 mapTypeStrs.push_back(
"private");
2038 mapTypeStrs.push_back(
"literal");
2040 mapTypeStrs.push_back(
"attach");
2042 mapTypeStrs.push_back(
"attach_always");
2044 mapTypeStrs.push_back(
"attach_never");
2046 mapTypeStrs.push_back(
"attach_auto");
2048 mapTypeStrs.push_back(
"ref_ptr");
2050 mapTypeStrs.push_back(
"ref_ptee");
2052 mapTypeStrs.push_back(
"ref_ptr_ptee");
2054 mapTypeStrs.push_back(
"is_device_ptr");
2055 if (mapFlags == ClauseMapFlags::none)
2056 mapTypeStrs.push_back(
"none");
2058 for (
unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2059 p << mapTypeStrs[i];
2060 if (i + 1 < mapTypeStrs.size()) {
2066static ParseResult parseMembersIndex(
OpAsmParser &parser,
2070 auto parseIndices = [&]() -> ParseResult {
2075 APInt(64, value,
false)));
2089 memberIdxs.push_back(ArrayAttr::get(parser.
getContext(), values));
2093 if (!memberIdxs.empty())
2094 membersIdx = ArrayAttr::get(parser.
getContext(), memberIdxs);
2104 llvm::interleaveComma(membersIdx, p, [&p](
Attribute v) {
2106 auto memberIdx = cast<ArrayAttr>(v);
2107 llvm::interleaveComma(memberIdx.getValue(), p, [&p](
Attribute v2) {
2108 p << cast<IntegerAttr>(v2).getInt();
2115 VariableCaptureKindAttr mapCaptureType) {
2116 std::string typeCapStr;
2117 llvm::raw_string_ostream typeCap(typeCapStr);
2118 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2120 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2121 typeCap <<
"ByCopy";
2122 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2123 typeCap <<
"VLAType";
2124 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2130 VariableCaptureKindAttr &mapCaptureType) {
2131 StringRef mapCaptureKey;
2135 if (mapCaptureKey ==
"This")
2136 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2137 parser.
getContext(), mlir::omp::VariableCaptureKind::This);
2138 if (mapCaptureKey ==
"ByRef")
2139 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2140 parser.
getContext(), mlir::omp::VariableCaptureKind::ByRef);
2141 if (mapCaptureKey ==
"ByCopy")
2142 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2143 parser.
getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2144 if (mapCaptureKey ==
"VLAType")
2145 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2146 parser.
getContext(), mlir::omp::VariableCaptureKind::VLAType);
2155 for (
auto mapOp : mapVars) {
2156 if (!mapOp.getDefiningOp())
2159 if (
auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2160 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2163 bool from =
mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2166 bool always =
mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2167 bool close =
mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2168 bool implicit =
mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2170 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2172 "to, from, tofrom and alloc map types are permitted");
2174 if (isa<TargetEnterDataOp>(op) && (from || del))
2175 return emitError(op->
getLoc(),
"to and alloc map types are permitted");
2177 if (isa<TargetExitDataOp>(op) && to)
2179 "from, release and delete map types are permitted");
2181 if (isa<TargetUpdateOp>(op)) {
2184 "at least one of to or from map types must be "
2185 "specified, other map types are not permitted");
2190 "at least one of to or from map types must be "
2191 "specified, other map types are not permitted");
2194 auto updateVar = mapInfoOp.getVarPtr();
2196 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2197 (from && updateToVars.contains(updateVar))) {
2200 "either to or from map types can be specified, not both");
2203 if (always || close || implicit) {
2206 "present, mapper and iterator map type modifiers are permitted");
2209 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2211 }
else if (!isa<DeclareMapperInfoOp>(op)) {
2213 "map argument is not a map entry operation");
2221 std::optional<DenseI64ArrayAttr> privateMapIndices =
2222 targetOp.getPrivateMapsAttr();
2225 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2230 if (privateMapIndices.value().size() !=
2231 static_cast<int64_t>(privateVars.size()))
2232 return emitError(targetOp.getLoc(),
"sizes of `private` operand range and "
2233 "`private_maps` attribute mismatch");
2243 StringRef clauseName,
2245 for (
Value var : vars)
2246 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2248 <<
"'" << clauseName
2249 <<
"' arguments must be defined by 'omp.map.info' ops";
2253LogicalResult MapInfoOp::verify() {
2254 if (getMapperId() &&
2256 *
this, getMapperIdAttr())) {
2271 const TargetDataOperands &clauses) {
2272 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2273 clauses.mapVars, clauses.useDeviceAddrVars,
2274 clauses.useDevicePtrVars);
2277LogicalResult TargetDataOp::verify() {
2278 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2279 getUseDeviceAddrVars().empty()) {
2280 return ::emitError(this->getLoc(),
2281 "At least one of map, use_device_ptr_vars, or "
2282 "use_device_addr_vars operand must be present");
2286 getUseDevicePtrVars())))
2290 getUseDeviceAddrVars())))
2300void TargetEnterDataOp::build(
2304 TargetEnterDataOp::build(
2306 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2307 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2311LogicalResult TargetEnterDataOp::verify() {
2312 LogicalResult verifyDependVars =
2314 getDependIteratedKinds(), getDependIterated());
2315 return failed(verifyDependVars) ? verifyDependVars
2326 TargetExitDataOp::build(
2328 clauses.dependVars,
makeArrayAttr(ctx, clauses.dependIteratedKinds),
2329 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2333LogicalResult TargetExitDataOp::verify() {
2334 LogicalResult verifyDependVars =
2336 getDependIteratedKinds(), getDependIterated());
2337 return failed(verifyDependVars) ? verifyDependVars
2348 TargetUpdateOp::build(builder, state,
makeArrayAttr(ctx, clauses.dependKinds),
2351 clauses.dependIterated, clauses.device, clauses.ifExpr,
2352 clauses.mapVars, clauses.nowait);
2355LogicalResult TargetUpdateOp::verify() {
2356 LogicalResult verifyDependVars =
2358 getDependIteratedKinds(), getDependIterated());
2359 return failed(verifyDependVars) ? verifyDependVars
2368 const TargetOperands &clauses) {
2373 builder, state, {}, {}, clauses.bare,
2374 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2375 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2376 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2379 nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2380 clauses.nowait, clauses.privateVars,
2381 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2382 clauses.threadLimitVars,
2386LogicalResult TargetOp::verify() {
2388 getDependIteratedKinds(),
2389 getDependIterated())))
2393 getHasDeviceAddrVars())))
2402LogicalResult TargetOp::verifyRegions() {
2403 auto teamsOps = getOps<TeamsOp>();
2404 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2405 return emitError(
"target containing multiple 'omp.teams' nested ops");
2408 Operation *capturedOp = getInnermostCapturedOmpOp();
2409 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2410 for (
Value hostEvalArg :
2411 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2413 if (
auto teamsOp = dyn_cast<TeamsOp>(user)) {
2415 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2416 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2417 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2420 return emitOpError() <<
"host_eval argument only legal as 'num_teams' "
2421 "and 'thread_limit' in 'omp.teams'";
2423 if (
auto parallelOp = dyn_cast<ParallelOp>(user)) {
2424 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2425 parallelOp->isAncestor(capturedOp) &&
2426 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2430 <<
"host_eval argument only legal as 'num_threads' in "
2431 "'omp.parallel' when representing target SPMD";
2433 if (
auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2434 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2435 loopNestOp.getOperation() == capturedOp &&
2436 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2437 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2441 return emitOpError() <<
"host_eval argument only legal as loop bounds "
2442 "and steps in 'omp.loop_nest' when trip count "
2443 "must be evaluated in the host";
2446 return emitOpError() <<
"host_eval argument illegal use in '"
2447 << user->getName() <<
"' operation";
2456 assert(rootOp &&
"expected valid operation");
2473 bool isOmpDialect = op->
getDialect() == ompDialect;
2475 if (!isOmpDialect || !hasRegions)
2482 if (checkSingleMandatoryExec) {
2487 if (successor->isReachable(parentBlock))
2490 for (
Block &block : *parentRegion)
2492 !domInfo.
dominates(parentBlock, &block))
2499 if (&sibling != op && !siblingAllowedFn(&sibling))
2512Operation *TargetOp::getInnermostCapturedOmpOp() {
2513 auto *ompDialect =
getContext()->getLoadedDialect<omp::OpenMPDialect>();
2525 if (
auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2528 memOp.getEffects(effects);
2529 return !llvm::any_of(
2531 return isa<MemoryEffects::Write>(effect.
getEffect()) &&
2532 isa<SideEffects::AutomaticAllocationScopeResource>(
2542 WsloopOp *wsLoopOp) {
2544 if (!teamsOp.getNumTeamsUpperVars().empty())
2548 if (teamsOp.getNumReductionVars())
2550 if (wsLoopOp->getNumReductionVars())
2554 OffloadModuleInterface offloadMod =
2558 auto ompFlags = offloadMod.getFlags();
2561 return ompFlags.getAssumeTeamsOversubscription() &&
2562 ompFlags.getAssumeThreadsOversubscription();
2565TargetRegionFlags TargetOp::getKernelExecFlags(
Operation *capturedOp) {
2570 assert((!capturedOp ||
2571 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2572 "unexpected captured op");
2575 if (!isa_and_present<LoopNestOp>(capturedOp))
2576 return TargetRegionFlags::generic;
2580 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2581 assert(!loopWrappers.empty());
2583 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2584 if (isa<SimdOp>(innermostWrapper))
2585 innermostWrapper = std::next(innermostWrapper);
2587 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2588 if (numWrappers != 1 && numWrappers != 2)
2589 return TargetRegionFlags::generic;
2592 if (numWrappers == 2) {
2593 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2595 return TargetRegionFlags::generic;
2597 innermostWrapper = std::next(innermostWrapper);
2598 if (!isa<DistributeOp>(innermostWrapper))
2599 return TargetRegionFlags::generic;
2602 if (!isa_and_present<ParallelOp>(parallelOp))
2603 return TargetRegionFlags::generic;
2605 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->
getParentOp());
2607 return TargetRegionFlags::generic;
2609 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2610 TargetRegionFlags
result =
2611 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2618 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2620 if (!isa_and_present<TeamsOp>(teamsOp))
2621 return TargetRegionFlags::generic;
2623 if (teamsOp->
getParentOp() != targetOp.getOperation())
2624 return TargetRegionFlags::generic;
2626 if (isa<LoopOp>(innermostWrapper))
2627 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2637 Dialect *ompDialect = targetOp->getDialect();
2641 return sibling && (ompDialect != sibling->
getDialect() ||
2645 TargetRegionFlags
result =
2646 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2651 while (nestedCapture->
getParentOp() != capturedOp)
2654 return isa<ParallelOp>(nestedCapture) ?
result | TargetRegionFlags::spmd
2658 else if (isa<WsloopOp>(innermostWrapper)) {
2660 if (!isa_and_present<ParallelOp>(parallelOp))
2661 return TargetRegionFlags::generic;
2663 if (parallelOp->
getParentOp() == targetOp.getOperation())
2664 return TargetRegionFlags::spmd;
2667 return TargetRegionFlags::generic;
2676 ParallelOp::build(builder, state,
ValueRange(),
2688 const ParallelOperands &clauses) {
2690 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2691 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2693 clauses.privateNeedsBarrier, clauses.procBindKind,
2694 clauses.reductionMod, clauses.reductionVars,
2699template <
typename OpType>
2701 auto privateVars = op.getPrivateVars();
2702 auto privateSyms = op.getPrivateSymsAttr();
2704 if (privateVars.empty() && (privateSyms ==
nullptr || privateSyms.empty()))
2707 auto numPrivateVars = privateVars.size();
2708 auto numPrivateSyms = (privateSyms ==
nullptr) ? 0 : privateSyms.size();
2710 if (numPrivateVars != numPrivateSyms)
2711 return op.emitError() <<
"inconsistent number of private variables and "
2712 "privatizer op symbols, private vars: "
2714 <<
" vs. privatizer op symbols: " << numPrivateSyms;
2716 for (
auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2717 Type varType = std::get<0>(privateVarInfo).getType();
2718 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2719 PrivateClauseOp privatizerOp =
2722 if (privatizerOp ==
nullptr)
2723 return op.emitError() <<
"failed to lookup privatizer op with symbol: '"
2724 << privateSym <<
"'";
2726 Type privatizerType = privatizerOp.getArgType();
2728 if (privatizerType && (varType != privatizerType))
2729 return op.emitError()
2730 <<
"type mismatch between a "
2731 << (privatizerOp.getDataSharingType() ==
2732 DataSharingClauseType::Private
2735 <<
" variable and its privatizer op, var type: " << varType
2736 <<
" vs. privatizer op type: " << privatizerType;
2742LogicalResult ParallelOp::verify() {
2743 if (getAllocateVars().size() != getAllocatorVars().size())
2745 "expected equal sizes for allocate and allocator variables");
2751 getReductionByref());
2754LogicalResult ParallelOp::verifyRegions() {
2755 auto distChildOps = getOps<DistributeOp>();
2756 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2757 if (numDistChildOps > 1)
2759 <<
"multiple 'omp.distribute' nested inside of 'omp.parallel'";
2761 if (numDistChildOps == 1) {
2764 <<
"'omp.composite' attribute missing from composite operation";
2766 auto *ompDialect =
getContext()->getLoadedDialect<OpenMPDialect>();
2767 Operation &distributeOp = **distChildOps.begin();
2769 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2773 return emitError() <<
"unexpected OpenMP operation inside of composite "
2775 << childOp.getName();
2777 }
else if (isComposite()) {
2779 <<
"'omp.composite' attribute present in non-composite operation";
2796 const TeamsOperands &clauses) {
2800 builder, state, clauses.allocateVars, clauses.allocatorVars,
2801 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2803 nullptr, clauses.reductionMod,
2804 clauses.reductionVars,
2806 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2813 if (numTeamsLower) {
2814 if (numTeamsUpperVars.size() != 1)
2816 "expected exactly one num_teams upper bound when lower bound is "
2820 "expected num_teams upper bound and lower bound to be "
2827LogicalResult TeamsOp::verify() {
2836 return emitError(
"expected to be nested inside of omp.target or not nested "
2837 "in any OpenMP dialect operations");
2841 this->getNumTeamsUpperVars())))
2845 if (getAllocateVars().size() != getAllocatorVars().size())
2847 "expected equal sizes for allocate and allocator variables");
2850 getReductionByref());
2858 return getParentOp().getPrivateVars();
2862 return getParentOp().getReductionVars();
2870 const SectionsOperands &clauses) {
2873 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2876 clauses.reductionMod, clauses.reductionVars,
2881LogicalResult SectionsOp::verify() {
2882 if (getAllocateVars().size() != getAllocatorVars().size())
2884 "expected equal sizes for allocate and allocator variables");
2887 getReductionByref());
2890LogicalResult SectionsOp::verifyRegions() {
2891 for (
auto &inst : *getRegion().begin()) {
2892 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2894 <<
"expected omp.section op or terminator op inside region";
2906 const SingleOperands &clauses) {
2909 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2910 clauses.copyprivateVars,
2911 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2916LogicalResult SingleOp::verify() {
2918 if (getAllocateVars().size() != getAllocatorVars().size())
2920 "expected equal sizes for allocate and allocator variables");
2923 getCopyprivateSyms());
2931 const WorkshareOperands &clauses) {
2932 WorkshareOp::build(builder, state, clauses.nowait);
2939LogicalResult WorkshareLoopWrapperOp::verify() {
2940 if (!(*this)->getParentOfType<WorkshareOp>())
2941 return emitOpError() <<
"must be nested in an omp.workshare";
2945LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2946 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2948 return emitOpError() <<
"expected to be a standalone loop wrapper";
2957LogicalResult LoopWrapperInterface::verifyImpl() {
2961 return emitOpError() <<
"loop wrapper must also have the `NoTerminator` "
2962 "and `SingleBlock` traits";
2965 return emitOpError() <<
"loop wrapper does not contain exactly one region";
2968 if (range_size(region.
getOps()) != 1)
2970 <<
"loop wrapper does not contain exactly one nested op";
2973 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2974 return emitOpError() <<
"nested in loop wrapper is not another loop "
2975 "wrapper or `omp.loop_nest`";
2985 const LoopOperands &clauses) {
2988 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2990 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2991 clauses.reductionMod, clauses.reductionVars,
2996LogicalResult LoopOp::verify() {
2998 getReductionByref());
3001LogicalResult LoopOp::verifyRegions() {
3002 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3004 return emitOpError() <<
"expected to be a standalone loop wrapper";
3015 build(builder, state, {}, {},
3018 false,
nullptr,
nullptr,
3019 nullptr, {},
nullptr,
3030 const WsloopOperands &clauses) {
3035 {}, {}, clauses.linearVars,
3036 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3037 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3038 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3039 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3041 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3042 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3045LogicalResult WsloopOp::verify() {
3049 if (getLinearVars().size() &&
3050 getLinearVarTypes().value().size() != getLinearVars().size())
3051 return emitError() <<
"Ill-formed type attributes for linear variables";
3053 getReductionByref());
3056LogicalResult WsloopOp::verifyRegions() {
3057 bool isCompositeChildLeaf =
3058 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3060 if (LoopWrapperInterface nested = getNestedWrapper()) {
3063 <<
"'omp.composite' attribute missing from composite wrapper";
3067 if (!isa<SimdOp>(nested))
3068 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3070 }
else if (isComposite() && !isCompositeChildLeaf) {
3072 <<
"'omp.composite' attribute present in non-composite wrapper";
3073 }
else if (!isComposite() && isCompositeChildLeaf) {
3075 <<
"'omp.composite' attribute missing from composite wrapper";
3086 const SimdOperands &clauses) {
3088 SimdOp::build(builder, state, clauses.alignedVars,
3090 clauses.linearVars, clauses.linearStepVars,
3091 clauses.linearVarTypes, clauses.linearModifiers,
3092 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3093 clauses.privateVars,
makeArrayAttr(ctx, clauses.privateSyms),
3094 clauses.privateNeedsBarrier, clauses.reductionMod,
3095 clauses.reductionVars,
3101LogicalResult SimdOp::verify() {
3102 if (getSimdlen().has_value() && getSafelen().has_value() &&
3103 getSimdlen().value() > getSafelen().value())
3105 <<
"simdlen clause and safelen clause are both present, but the "
3106 "simdlen value is not less than or equal to safelen value";
3118 bool isCompositeChildLeaf =
3119 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3121 if (!isComposite() && isCompositeChildLeaf)
3123 <<
"'omp.composite' attribute missing from composite wrapper";
3125 if (isComposite() && !isCompositeChildLeaf)
3127 <<
"'omp.composite' attribute present in non-composite wrapper";
3131 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3133 for (
const Attribute &sym : *privateSyms) {
3134 auto symRef = cast<SymbolRefAttr>(sym);
3135 omp::PrivateClauseOp privatizer =
3137 getOperation(), symRef);
3139 return emitError() <<
"Cannot find privatizer '" << symRef <<
"'";
3140 if (privatizer.getDataSharingType() ==
3141 DataSharingClauseType::FirstPrivate)
3142 return emitError() <<
"FIRSTPRIVATE cannot be used with SIMD";
3146 if (getLinearVars().size() &&
3147 getLinearVarTypes().value().size() != getLinearVars().size())
3148 return emitError() <<
"Ill-formed type attributes for linear variables";
3152LogicalResult SimdOp::verifyRegions() {
3153 if (getNestedWrapper())
3154 return emitOpError() <<
"must wrap an 'omp.loop_nest' directly";
3164 const DistributeOperands &clauses) {
3165 DistributeOp::build(builder, state, clauses.allocateVars,
3166 clauses.allocatorVars, clauses.distScheduleStatic,
3167 clauses.distScheduleChunkSize, clauses.order,
3168 clauses.orderMod, clauses.privateVars,
3170 clauses.privateNeedsBarrier);
3173LogicalResult DistributeOp::verify() {
3174 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3176 "dist_schedule_static being present";
3178 if (getAllocateVars().size() != getAllocatorVars().size())
3180 "expected equal sizes for allocate and allocator variables");
3185LogicalResult DistributeOp::verifyRegions() {
3186 if (LoopWrapperInterface nested = getNestedWrapper()) {
3189 <<
"'omp.composite' attribute missing from composite wrapper";
3192 if (isa<WsloopOp>(nested)) {
3194 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3195 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3196 return emitError() <<
"an 'omp.wsloop' nested wrapper is only allowed "
3197 "when a composite 'omp.parallel' is the direct "
3200 }
else if (!isa<SimdOp>(nested))
3201 return emitError() <<
"only supported nested wrappers are 'omp.simd' and "
3203 }
else if (isComposite()) {
3205 <<
"'omp.composite' attribute present in non-composite wrapper";
3215LogicalResult DeclareMapperInfoOp::verify() {
3219LogicalResult DeclareMapperOp::verifyRegions() {
3220 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3221 getRegion().getBlocks().front().getTerminator()))
3222 return emitOpError() <<
"expected terminator to be a DeclareMapperInfoOp";
3231LogicalResult DeclareReductionOp::verifyRegions() {
3232 if (!getAllocRegion().empty()) {
3233 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3234 if (yieldOp.getResults().size() != 1 ||
3235 yieldOp.getResults().getTypes()[0] !=
getType())
3236 return emitOpError() <<
"expects alloc region to yield a value "
3237 "of the reduction type";
3241 if (getInitializerRegion().empty())
3242 return emitOpError() <<
"expects non-empty initializer region";
3243 Block &initializerEntryBlock = getInitializerRegion().
front();
3246 if (!getAllocRegion().empty())
3247 return emitOpError() <<
"expects two arguments to the initializer region "
3248 "when an allocation region is used";
3250 if (getAllocRegion().empty())
3251 return emitOpError() <<
"expects one argument to the initializer region "
3252 "when no allocation region is used";
3255 <<
"expects one or two arguments to the initializer region";
3259 if (arg.getType() !=
getType())
3260 return emitOpError() <<
"expects initializer region argument to match "
3261 "the reduction type";
3263 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3264 if (yieldOp.getResults().size() != 1 ||
3265 yieldOp.getResults().getTypes()[0] !=
getType())
3266 return emitOpError() <<
"expects initializer region to yield a value "
3267 "of the reduction type";
3270 if (getReductionRegion().empty())
3271 return emitOpError() <<
"expects non-empty reduction region";
3272 Block &reductionEntryBlock = getReductionRegion().
front();
3277 return emitOpError() <<
"expects reduction region with two arguments of "
3278 "the reduction type";
3279 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3280 if (yieldOp.getResults().size() != 1 ||
3281 yieldOp.getResults().getTypes()[0] !=
getType())
3282 return emitOpError() <<
"expects reduction region to yield a value "
3283 "of the reduction type";
3286 if (!getAtomicReductionRegion().empty()) {
3287 Block &atomicReductionEntryBlock = getAtomicReductionRegion().
front();
3291 return emitOpError() <<
"expects atomic reduction region with two "
3292 "arguments of the same type";
3293 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3296 (ptrType.getElementType() && ptrType.getElementType() !=
getType()))
3297 return emitOpError() <<
"expects atomic reduction region arguments to "
3298 "be accumulators containing the reduction type";
3301 if (getCleanupRegion().empty())
3303 Block &cleanupEntryBlock = getCleanupRegion().
front();
3306 return emitOpError() <<
"expects cleanup region with one argument "
3307 "of the reduction type";
3317 const TaskOperands &clauses) {
3320 builder, state, clauses.iterated, clauses.affinityVars,
3321 clauses.allocateVars, clauses.allocatorVars,
3322 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3323 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3324 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3326 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3327 clauses.priority, clauses.privateVars,
3329 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3332LogicalResult TaskOp::verify() {
3333 LogicalResult verifyDependVars =
3335 getDependIteratedKinds(), getDependIterated());
3336 return failed(verifyDependVars)
3339 getInReductionVars(),
3340 getInReductionByref());
3348 const TaskgroupOperands &clauses) {
3350 TaskgroupOp::build(builder, state, clauses.allocateVars,
3351 clauses.allocatorVars, clauses.taskReductionVars,
3356LogicalResult TaskgroupOp::verify() {
3358 getTaskReductionVars(),
3359 getTaskReductionByref());
3367 const TaskloopContextOperands &clauses) {
3369 TaskloopContextOp::build(
3370 builder, state, clauses.allocateVars, clauses.allocatorVars,
3371 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3372 clauses.inReductionVars,
3374 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3375 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3376 clauses.privateVars,
3378 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3383TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3384 return cast<TaskloopWrapperOp>(
3386 return isa<TaskloopWrapperOp>(op);
3390LogicalResult TaskloopContextOp::verify() {
3391 if (getAllocateVars().size() != getAllocatorVars().size())
3393 "expected equal sizes for allocate and allocator variables");
3395 getReductionVars(), getReductionByref())) ||
3397 getInReductionVars(),
3398 getInReductionByref())))
3401 if (!getReductionVars().empty() && getNogroup())
3402 return emitError(
"if a reduction clause is present on the taskloop "
3403 "directive, the nogroup clause must not be specified");
3404 for (
auto var : getReductionVars()) {
3405 if (llvm::is_contained(getInReductionVars(), var))
3406 return emitError(
"the same list item cannot appear in both a reduction "
3407 "and an in_reduction clause");
3410 if (getGrainsize() && getNumTasks()) {
3412 "the grainsize clause and num_tasks clause are mutually exclusive and "
3413 "may not appear on the same taskloop directive");
3419LogicalResult TaskloopContextOp::verifyRegions() {
3420 Region ®ion = getRegion();
3422 return emitOpError() <<
"expected non-empty region";
3425 return isa<TaskloopWrapperOp>(op);
3429 <<
"expected exactly 1 TaskloopWrapperOp directly nested in "
3431 << count <<
" were found";
3432 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3434 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3440 auto isDefinedInTaskloopContext = [&](
Value value) {
3442 return region.
isAncestor(value.getParentRegion());
3445 return llvm::any_of(range, isDefinedInTaskloopContext);
3448 if (hasTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3449 hasTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3450 hasTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3451 return emitOpError() <<
"expects loop bounds and steps to be defined "
3452 "outside of the taskloop.context region";
3463 const TaskloopWrapperOperands &clauses) {
3464 TaskloopWrapperOp::build(builder, state);
3467TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3468 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3471LogicalResult TaskloopWrapperOp::verify() {
3472 TaskloopContextOp context = getTaskloopContext();
3474 return emitOpError() <<
"expected to be nested in a taskloop context op";
3478LogicalResult TaskloopWrapperOp::verifyRegions() {
3479 if (LoopWrapperInterface nested = getNestedWrapper()) {
3482 <<
"'omp.composite' attribute missing from composite wrapper";
3486 if (!isa<SimdOp>(nested))
3487 return emitError() <<
"only supported nested wrapper is 'omp.simd'";
3488 }
else if (isComposite()) {
3490 <<
"'omp.composite' attribute present in non-composite wrapper";
3514 for (
auto &iv : ivs)
3515 iv.type = loopVarType;
3520 result.addAttribute(
"loop_inclusive", UnitAttr::get(ctx));
3536 "collapse_num_loops",
3541 auto parseTiles = [&]() -> ParseResult {
3545 tiles.push_back(
tile);
3554 if (tiles.size() > 0)
3573 Region ®ion = getRegion();
3575 p <<
" (" << args <<
") : " << args[0].getType() <<
" = ("
3576 << getLoopLowerBounds() <<
") to (" << getLoopUpperBounds() <<
") ";
3577 if (getLoopInclusive())
3579 p <<
"step (" << getLoopSteps() <<
") ";
3580 if (
int64_t numCollapse = getCollapseNumLoops())
3581 if (numCollapse > 1)
3582 p <<
"collapse(" << numCollapse <<
") ";
3585 p <<
"tiles(" << tiles.value() <<
") ";
3591 const LoopNestOperands &clauses) {
3593 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3594 clauses.loopLowerBounds, clauses.loopUpperBounds,
3595 clauses.loopSteps, clauses.loopInclusive,
3599LogicalResult LoopNestOp::verify() {
3600 if (getLoopLowerBounds().empty())
3601 return emitOpError() <<
"must represent at least one loop";
3603 if (getLoopLowerBounds().size() != getIVs().size())
3604 return emitOpError() <<
"number of range arguments and IVs do not match";
3606 for (
auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3607 if (lb.getType() != iv.getType())
3609 <<
"range argument type does not match corresponding IV type";
3612 uint64_t numIVs = getIVs().size();
3614 if (
const auto &numCollapse = getCollapseNumLoops())
3615 if (numCollapse > numIVs)
3617 <<
"collapse value is larger than the number of loops";
3620 if (tiles.value().size() > numIVs)
3621 return emitOpError() <<
"too few canonical loops for tile dimensions";
3623 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3624 return emitOpError() <<
"expects parent op to be a loop wrapper";
3629void LoopNestOp::gatherWrappers(
3632 while (
auto wrapper =
3633 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3634 wrappers.push_back(wrapper);
3643std::tuple<NewCliOp, OpOperand *, OpOperand *>
3649 return {{},
nullptr,
nullptr};
3652 "Unexpected type of cli");
3658 auto op = cast<LoopTransformationInterface>(use.getOwner());
3660 unsigned opnum = use.getOperandNumber();
3661 if (op.isGeneratee(opnum)) {
3662 assert(!gen &&
"Each CLI may have at most one def");
3664 }
else if (op.isApplyee(opnum)) {
3665 assert(!cons &&
"Each CLI may have at most one consumer");
3668 llvm_unreachable(
"Unexpected operand for a CLI");
3672 return {create, gen, cons};
3695 std::string cliName{
"cli"};
3699 .Case([&](CanonicalLoopOp op) {
3702 .Case([&](UnrollHeuristicOp op) -> std::string {
3703 llvm_unreachable(
"heuristic unrolling does not generate a loop");
3705 .Case([&](FuseOp op) -> std::string {
3706 unsigned opnum =
generator->getOperandNumber();
3709 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3710 return "canonloop_fuse";
3714 .Case([&](TileOp op) -> std::string {
3715 auto [generateesFirst, generateesCount] =
3716 op.getGenerateesODSOperandIndexAndLength();
3717 unsigned firstGrid = generateesFirst;
3718 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3719 unsigned end = generateesFirst + generateesCount;
3720 unsigned opnum =
generator->getOperandNumber();
3722 if (firstGrid <= opnum && opnum < firstIntratile) {
3723 unsigned gridnum = opnum - firstGrid + 1;
3724 return (
"grid" + Twine(gridnum)).str();
3726 if (firstIntratile <= opnum && opnum < end) {
3727 unsigned intratilenum = opnum - firstIntratile + 1;
3728 return (
"intratile" + Twine(intratilenum)).str();
3730 llvm_unreachable(
"Unexpected generatee argument");
3732 .DefaultUnreachable(
"TODO: Custom name for this operation");
3735 setNameFn(
result, cliName);
3738LogicalResult NewCliOp::verify() {
3739 Value cli = getResult();
3742 "Unexpected type of cli");
3748 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3750 unsigned opnum = use.getOperandNumber();
3751 if (op.isGeneratee(opnum)) {
3754 emitOpError(
"CLI must have at most one generator");
3756 .
append(
"first generator here:");
3758 .
append(
"second generator here:");
3763 }
else if (op.isApplyee(opnum)) {
3766 emitOpError(
"CLI must have at most one consumer");
3768 .
append(
"first consumer here:")
3772 .
append(
"second consumer here:")
3779 llvm_unreachable(
"Unexpected operand for a CLI");
3787 .
append(
"see consumer here: ")
3810 setNameFn(&getRegion().front(),
"body_entry");
3813void CanonicalLoopOp::getAsmBlockArgumentNames(
Region ®ion,
3821 p <<
'(' << getCli() <<
')';
3822 p <<
' ' << getInductionVar() <<
" : " << getInductionVar().getType()
3823 <<
" in range(" << getTripCount() <<
") ";
3833 CanonicalLoopInfoType cliType =
3834 CanonicalLoopInfoType::get(parser.
getContext());
3859 if (parser.
parseRegion(*region, {inductionVariable}))
3864 result.operands.append(cliOperand);
3870 return mlir::success();
3873LogicalResult CanonicalLoopOp::verify() {
3876 if (!getRegion().empty()) {
3877 Region ®ion = getRegion();
3880 "Canonical loop region must have exactly one argument");
3884 "Region argument must be the same type as the trip count");
3890Value CanonicalLoopOp::getInductionVar() {
return getRegion().getArgument(0); }
3892std::pair<unsigned, unsigned>
3893CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3898std::pair<unsigned, unsigned>
3899CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3900 return getODSOperandIndexAndLength(odsIndex_cli);
3914 p <<
'(' << getApplyee() <<
')';
3921 auto cliType = CanonicalLoopInfoType::get(parser.
getContext());
3944 return mlir::success();
3947std::pair<unsigned, unsigned>
3948UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3949 return getODSOperandIndexAndLength(odsIndex_applyee);
3952std::pair<unsigned, unsigned>
3953UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3964 if (!generatees.empty())
3965 p <<
'(' << llvm::interleaved(generatees) <<
')';
3967 if (!applyees.empty())
3968 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4010 bool isOnlyCanonLoops =
true;
4012 for (
Value applyee : op.getApplyees()) {
4013 auto [create, gen, cons] =
decodeCli(applyee);
4016 return op.emitOpError() <<
"applyee CLI has no generator";
4018 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4019 canonLoops.push_back(loop);
4021 isOnlyCanonLoops =
false;
4026 if (!isOnlyCanonLoops)
4030 for (
auto i : llvm::seq<int>(1, canonLoops.size())) {
4031 auto parentLoop = canonLoops[i - 1];
4032 auto loop = canonLoops[i];
4034 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4035 return op.emitOpError()
4036 <<
"tiled loop nest must be nested within each other";
4038 parentIVs.insert(parentLoop.getInductionVar());
4043 bool isPerfectlyNested = [&]() {
4044 auto &parentBody = parentLoop.getRegion();
4045 if (!parentBody.hasOneBlock())
4047 auto &parentBlock = parentBody.getBlocks().front();
4049 auto nestedLoopIt = parentBlock.begin();
4050 if (nestedLoopIt == parentBlock.end() ||
4051 (&*nestedLoopIt != loop.getOperation()))
4054 auto termIt = std::next(nestedLoopIt);
4055 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4058 if (std::next(termIt) != parentBlock.end())
4063 if (!isPerfectlyNested)
4064 return op.emitOpError() <<
"tiled loop nest must be perfectly nested";
4066 if (parentIVs.contains(loop.getTripCount()))
4067 return op.emitOpError() <<
"tiled loop nest must be rectangular";
4084LogicalResult TileOp::verify() {
4085 if (getApplyees().empty())
4086 return emitOpError() <<
"must apply to at least one loop";
4088 if (getSizes().size() != getApplyees().size())
4089 return emitOpError() <<
"there must be one tile size for each applyee";
4091 if (!getGeneratees().empty() &&
4092 2 * getSizes().size() != getGeneratees().size())
4094 <<
"expecting two times the number of generatees than applyees";
4099std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4100 return getODSOperandIndexAndLength(odsIndex_applyees);
4103std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4104 return getODSOperandIndexAndLength(odsIndex_generatees);
4114 if (!generatees.empty())
4115 p <<
'(' << llvm::interleaved(generatees) <<
')';
4117 if (!applyees.empty())
4118 p <<
" <- (" << llvm::interleaved(applyees) <<
')';
4121LogicalResult FuseOp::verify() {
4122 if (getApplyees().size() < 2)
4123 return emitOpError() <<
"must apply to at least two loops";
4125 if (getFirst().has_value() && getCount().has_value()) {
4126 int64_t first = getFirst().value();
4127 int64_t count = getCount().value();
4128 if ((
unsigned)(first + count - 1) > getApplyees().size())
4129 return emitOpError() <<
"the numbers of applyees must be at least first "
4130 "minus one plus count attributes";
4131 if (!getGeneratees().empty() &&
4132 getGeneratees().size() != getApplyees().size() + 1 - count)
4133 return emitOpError() <<
"the number of generatees must be the number of "
4134 "aplyees plus one minus count";
4137 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4139 <<
"in a complete fuse the number of generatees must be exactly 1";
4141 for (
auto &&applyee : getApplyees()) {
4142 auto [create, gen, cons] =
decodeCli(applyee);
4145 return emitOpError() <<
"applyee CLI has no generator";
4146 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4149 <<
"currently only supports omp.canonical_loop as applyee";
4153std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4154 return getODSOperandIndexAndLength(odsIndex_applyees);
4157std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4158 return getODSOperandIndexAndLength(odsIndex_generatees);
4166 const CriticalDeclareOperands &clauses) {
4167 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4170LogicalResult CriticalDeclareOp::verify() {
4175 if (getNameAttr()) {
4176 SymbolRefAttr symbolRef = getNameAttr();
4180 return emitOpError() <<
"expected symbol reference " << symbolRef
4181 <<
" to point to a critical declaration";
4201 return op.
emitOpError() <<
"must be nested inside of a loop";
4205 if (
auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4206 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4208 return op.
emitOpError() <<
"the enclosing worksharing-loop region must "
4209 "have an ordered clause";
4211 if (hasRegion && orderedAttr.getInt() != 0)
4212 return op.
emitOpError() <<
"the enclosing loop's ordered clause must not "
4213 "have a parameter present";
4215 if (!hasRegion && orderedAttr.getInt() == 0)
4216 return op.
emitOpError() <<
"the enclosing loop's ordered clause must "
4217 "have a parameter present";
4218 }
else if (!isa<SimdOp>(wrapper)) {
4219 return op.
emitOpError() <<
"must be nested inside of a worksharing, simd "
4220 "or worksharing simd loop";
4226 const OrderedOperands &clauses) {
4227 OrderedOp::build(builder, state, clauses.doacrossDependType,
4228 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4231LogicalResult OrderedOp::verify() {
4235 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4236 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4237 return emitOpError() <<
"number of variables in depend clause does not "
4238 <<
"match number of iteration variables in the "
4245 const OrderedRegionOperands &clauses) {
4246 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4256 const TaskwaitOperands &clauses) {
4258 TaskwaitOp::build(builder, state,
nullptr,
4267LogicalResult AtomicReadOp::verify() {
4268 if (verifyCommon().
failed())
4269 return mlir::failure();
4271 if (
auto mo = getMemoryOrder()) {
4272 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4273 *mo == ClauseMemoryOrderKind::Release) {
4275 "memory-order must not be acq_rel or release for atomic reads");
4285LogicalResult AtomicWriteOp::verify() {
4286 if (verifyCommon().
failed())
4287 return mlir::failure();
4289 if (
auto mo = getMemoryOrder()) {
4290 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4291 *mo == ClauseMemoryOrderKind::Acquire) {
4293 "memory-order must not be acq_rel or acquire for atomic writes");
4303LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4309 if (
Value writeVal = op.getWriteOpVal()) {
4311 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4317LogicalResult AtomicUpdateOp::verify() {
4318 if (verifyCommon().
failed())
4319 return mlir::failure();
4321 if (
auto mo = getMemoryOrder()) {
4322 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4323 *mo == ClauseMemoryOrderKind::Acquire) {
4325 "memory-order must not be acq_rel or acquire for atomic updates");
4332LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4338AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4339 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4341 return dyn_cast<AtomicReadOp>(getSecondOp());
4344AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4345 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4347 return dyn_cast<AtomicWriteOp>(getSecondOp());
4350AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4351 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4353 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4356LogicalResult AtomicCaptureOp::verify() {
4360LogicalResult AtomicCaptureOp::verifyRegions() {
4361 if (verifyRegionsCommon().
failed())
4362 return mlir::failure();
4364 if (getFirstOp()->getAttr(
"hint") || getSecondOp()->getAttr(
"hint"))
4366 "operations inside capture region must not have hint clause");
4368 if (getFirstOp()->getAttr(
"memory_order") ||
4369 getSecondOp()->getAttr(
"memory_order"))
4371 "operations inside capture region must not have memory_order clause");
4380 const CancelOperands &clauses) {
4381 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4394LogicalResult CancelOp::verify() {
4395 ClauseCancellationConstructType cct = getCancelDirective();
4398 if (!structuralParent)
4399 return emitOpError() <<
"Orphaned cancel construct";
4401 if ((cct == ClauseCancellationConstructType::Parallel) &&
4402 !mlir::isa<ParallelOp>(structuralParent)) {
4403 return emitOpError() <<
"cancel parallel must appear "
4404 <<
"inside a parallel region";
4406 if (cct == ClauseCancellationConstructType::Loop) {
4409 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->
getParentOp());
4413 <<
"cancel loop must appear inside a worksharing-loop region";
4415 if (wsloopOp.getNowaitAttr()) {
4416 return emitError() <<
"A worksharing construct that is canceled "
4417 <<
"must not have a nowait clause";
4419 if (wsloopOp.getOrderedAttr()) {
4420 return emitError() <<
"A worksharing construct that is canceled "
4421 <<
"must not have an ordered clause";
4424 }
else if (cct == ClauseCancellationConstructType::Sections) {
4428 mlir::dyn_cast<SectionsOp>(structuralParent->
getParentOp());
4430 return emitOpError() <<
"cancel sections must appear "
4431 <<
"inside a sections region";
4433 if (sectionsOp.getNowait()) {
4434 return emitError() <<
"A sections construct that is canceled "
4435 <<
"must not have a nowait clause";
4438 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4439 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4440 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4441 return emitOpError() <<
"cancel taskgroup must appear "
4442 <<
"inside a task region";
4452 const CancellationPointOperands &clauses) {
4453 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4456LogicalResult CancellationPointOp::verify() {
4457 ClauseCancellationConstructType cct = getCancelDirective();
4460 if (!structuralParent)
4461 return emitOpError() <<
"Orphaned cancellation point";
4463 if ((cct == ClauseCancellationConstructType::Parallel) &&
4464 !mlir::isa<ParallelOp>(structuralParent)) {
4465 return emitOpError() <<
"cancellation point parallel must appear "
4466 <<
"inside a parallel region";
4470 if ((cct == ClauseCancellationConstructType::Loop) &&
4471 !mlir::isa<WsloopOp>(structuralParent->
getParentOp())) {
4472 return emitOpError() <<
"cancellation point loop must appear "
4473 <<
"inside a worksharing-loop region";
4475 if ((cct == ClauseCancellationConstructType::Sections) &&
4476 !mlir::isa<omp::SectionOp>(structuralParent)) {
4477 return emitOpError() <<
"cancellation point sections must appear "
4478 <<
"inside a sections region";
4480 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4481 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4482 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->
getParentOp()))) {
4483 return emitOpError() <<
"cancellation point taskgroup must appear "
4484 <<
"inside a task region";
4493LogicalResult MapBoundsOp::verify() {
4494 auto extent = getExtent();
4496 if (!extent && !upperbound)
4497 return emitError(
"expected extent or upperbound.");
4504 PrivateClauseOp::build(
4505 odsBuilder, odsState, symName, type,
4506 DataSharingClauseTypeAttr::get(odsBuilder.
getContext(),
4507 DataSharingClauseType::Private));
4510LogicalResult PrivateClauseOp::verifyRegions() {
4511 Type argType = getArgType();
4512 auto verifyTerminator = [&](
Operation *terminator,
4513 bool yieldsValue) -> LogicalResult {
4517 if (!llvm::isa<YieldOp>(terminator))
4519 <<
"expected exit block terminator to be an `omp.yield` op.";
4521 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4522 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4525 if (yieldedTypes.empty())
4529 <<
"Did not expect any values to be yielded.";
4532 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4536 <<
"Invalid yielded value. Expected type: " << argType
4539 if (yieldedTypes.empty())
4542 error << yieldedTypes;
4548 StringRef regionName,
4549 bool yieldsValue) -> LogicalResult {
4550 assert(!region.
empty());
4554 <<
"`" << regionName <<
"`: "
4555 <<
"expected " << expectedNumArgs
4558 for (
Block &block : region) {
4560 if (!block.mightHaveTerminator())
4563 if (
failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4571 for (
Region *region : getRegions())
4572 for (
Type ty : region->getArgumentTypes())
4574 return emitError() <<
"Region argument type mismatch: got " << ty
4575 <<
" expected " << argType <<
".";
4578 if (!initRegion.
empty() &&
4583 DataSharingClauseType dsType = getDataSharingType();
4585 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4586 return emitError(
"`private` clauses do not require a `copy` region.");
4588 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4590 "`firstprivate` clauses require at least a `copy` region.");
4592 if (dsType == DataSharingClauseType::FirstPrivate &&
4597 if (!getDeallocRegion().empty() &&
4610 const MaskedOperands &clauses) {
4611 MaskedOp::build(builder, state, clauses.filteredThreadId);
4619 const ScanOperands &clauses) {
4620 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4623LogicalResult ScanOp::verify() {
4624 if (hasExclusiveVars() == hasInclusiveVars())
4626 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4627 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4628 if (parentWsLoopOp.getReductionModAttr() &&
4629 parentWsLoopOp.getReductionModAttr().getValue() ==
4630 ReductionModifier::inscan)
4633 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4634 if (parentSimdOp.getReductionModAttr() &&
4635 parentSimdOp.getReductionModAttr().getValue() ==
4636 ReductionModifier::inscan)
4639 return emitError(
"SCAN directive needs to be enclosed within a parent "
4640 "worksharing loop construct or SIMD construct with INSCAN "
4641 "reduction modifier");
4646LogicalResult AllocateDirOp::verify() {
4647 std::optional<uint64_t> align = this->getAlign();
4649 if (align.has_value()) {
4650 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4651 return emitError() <<
"ALIGN value : " << align.value()
4652 <<
" must be power of 2";
4662mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4663 return getInTypeAttr().getValue();
4672 bool hasOperands =
false;
4673 std::int32_t typeparamsSize = 0;
4679 return mlir::failure();
4681 return mlir::failure();
4683 return mlir::failure();
4687 return mlir::failure();
4688 result.addAttribute(
"in_type", mlir::TypeAttr::get(intype));
4695 return mlir::failure();
4696 typeparamsSize = operands.size();
4699 std::int32_t shapeSize = 0;
4703 return mlir::failure();
4704 shapeSize = operands.size() - typeparamsSize;
4706 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4707 typeVec.push_back(idxTy);
4713 return mlir::failure();
4718 return mlir::failure();
4721 result.addAttribute(
"operandSegmentSizes",
4725 return mlir::failure();
4726 return mlir::success();
4741 if (!getTypeparams().empty()) {
4742 p <<
'(' << getTypeparams() <<
" : " << getTypeparams().getTypes() <<
')';
4749 {
"in_type",
"operandSegmentSizes"});
4752llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4754 if (!mlir::dyn_cast<IntegerType>(outType))
4756 return mlir::success();
4763LogicalResult WorkdistributeOp::verify() {
4765 Region ®ion = getRegion();
4770 if (entryBlock.
empty())
4771 return emitOpError(
"region must contain a structured block");
4773 bool hasTerminator =
false;
4774 for (
Block &block : region) {
4775 if (isa<TerminatorOp>(block.back())) {
4776 if (hasTerminator) {
4777 return emitOpError(
"region must have exactly one terminator");
4779 hasTerminator =
true;
4782 if (!hasTerminator) {
4783 return emitOpError(
"region must be terminated with omp.terminator");
4787 if (isa<BarrierOp>(op)) {
4789 "explicit barriers are not allowed in workdistribute region");
4792 if (isa<ParallelOp>(op)) {
4794 "nested parallel constructs not allowed in workdistribute");
4796 if (isa<TeamsOp>(op)) {
4798 "nested teams constructs not allowed in workdistribute");
4802 if (walkResult.wasInterrupted())
4806 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4807 return emitOpError(
"workdistribute must be nested under teams");
4815LogicalResult DeclareSimdOp::verify() {
4818 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4820 return emitOpError() <<
"must be nested inside a function";
4822 if (getInbranch() && getNotinbranch())
4823 return emitOpError(
"cannot have both 'inbranch' and 'notinbranch'");
4833 const DeclareSimdOperands &clauses) {
4835 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4837 clauses.linearVars, clauses.linearStepVars,
4838 clauses.linearVarTypes, clauses.linearModifiers,
4839 clauses.notinbranch, clauses.simdlen,
4840 clauses.uniformVars);
4857 return mlir::failure();
4858 return mlir::success();
4865 for (
unsigned i = 0; i < uniformVars.size(); ++i) {
4868 p << uniformVars[i] <<
" : " << uniformTypes[i];
4883 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4884 [&]() -> ParseResult {
return success(); })))
4918 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4919 if (parser.parseArgument(arg))
4923 if (succeeded(parser.parseOptionalColon())) {
4924 if (parser.parseType(arg.type))
4927 arg.type = parser.getBuilder().getIndexType();
4939 OpAsmParser::UnresolvedOperand lb, ub, st;
4940 if (parser.parseOperand(lb) || parser.parseKeyword(
"to") ||
4941 parser.parseOperand(ub) || parser.parseKeyword(
"step") ||
4942 parser.parseOperand(st))
4947 steps.push_back(st);
4955 if (ivArgs.size() != lbs.size())
4957 <<
"mismatch: " << ivArgs.size() <<
" variables but " << lbs.size()
4960 for (
auto &arg : ivArgs) {
4961 lbTypes.push_back(arg.type);
4962 ubTypes.push_back(arg.type);
4963 stepTypes.push_back(arg.type);
4983 for (
unsigned i = 0, e = lbs.size(); i < e; ++i) {
4986 p << lbs[i] <<
" to " << ubs[i] <<
" step " << steps[i];
4994LogicalResult IteratorOp::verify() {
4995 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().
getType());
4997 return emitOpError() <<
"result must be omp.iterated<entry_ty>";
4999 for (
auto [lb,
ub, step] : llvm::zip_equal(
5000 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5002 return emitOpError() <<
"loop step must not be zero";
5006 IntegerAttr stepAttr;
5012 const APInt &lbVal = lbAttr.getValue();
5013 const APInt &ubVal = ubAttr.getValue();
5014 const APInt &stepVal = stepAttr.getValue();
5015 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5016 return emitOpError() <<
"positive loop step requires lower bound to be "
5017 "less than or equal to upper bound";
5018 if (stepVal.isNegative() && lbVal.slt(ubVal))
5019 return emitOpError() <<
"negative loop step requires lower bound to be "
5020 "greater than or equal to upper bound";
5023 Block &
b = getRegion().front();
5024 auto yield = llvm::dyn_cast<omp::YieldOp>(
b.getTerminator());
5027 return emitOpError() <<
"region must be terminated by omp.yield";
5029 if (yield.getNumOperands() != 1)
5031 <<
"omp.yield in omp.iterator region must yield exactly one value";
5033 mlir::Type yieldedTy = yield.getOperand(0).getType();
5034 mlir::Type elemTy = iteratedTy.getElementType();
5036 if (yieldedTy != elemTy)
5037 return emitOpError() <<
"omp.iterated element type (" << elemTy
5038 <<
") does not match omp.yield operand type ("
5039 << yieldedTy <<
")";
5044#define GET_ATTRDEF_CLASSES
5045#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5047#define GET_OP_CLASSES
5048#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5050#define GET_TYPEDEF_CLASSES
5051#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region ®ion, const Twine &name)
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region ®ion, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region ®ion, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
SuccessorRange getSuccessors()
BlockArgListType getArguments()
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
IntegerType getIntegerType(unsigned width)
MLIRContext * getContext() const
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
A class for computing basic dominance information.
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
This class represents an operand of an operation.
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
type_range getType() const
Operation is the basic unit of execution within MLIR.
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Block * getBlock()
Returns the operation block that contains this operation.
unsigned getNumRegions()
Returns the number of regions held by this operation.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
user_range getUsers()
Returns a range of all users.
Region * getParentRegion()
Returns the region to which the instruction belongs.
MLIRContext * getContext()
Return the context this operation is associated with.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
BlockArgListType getArguments()
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
iterator_range< OpIterator > getOps()
unsigned getNumArguments()
Location getLoc()
Return a location for this region.
BlockArgument getArgument(unsigned i)
Operation * getParentOp()
Return the parent operation this region is attached to.
BlockListType & getBlocks()
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Type getType() const
Return the type of this value.
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
A utility result that is used to signal how to proceed with an ongoing walk:
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
llvm::TypeSwitch< T, ResultT > TypeSwitch
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.