21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36 static bool isScalarLikeType(
Type type) {
40 struct MemRefPointerLikeModel
41 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
44 return cast<MemRefType>(pointer).getElementType();
46 mlir::acc::VariableTypeCategory
49 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
50 return mappableTy.getTypeCategory(varPtr);
52 auto memrefTy = cast<MemRefType>(pointer);
53 if (!memrefTy.hasRank()) {
56 return mlir::acc::VariableTypeCategory::uncategorized;
59 if (memrefTy.getRank() == 0) {
60 if (isScalarLikeType(memrefTy.getElementType())) {
61 return mlir::acc::VariableTypeCategory::scalar;
65 return mlir::acc::VariableTypeCategory::uncategorized;
69 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
70 return mlir::acc::VariableTypeCategory::array;
74 struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
84 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
88 if (existingDeviceTypes)
89 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
91 if (newDeviceTypes.empty())
92 deviceTypes.push_back(
95 for (DeviceType DT : newDeviceTypes)
107 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
113 if (existingDeviceTypes)
114 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
116 if (newDeviceTypes.empty()) {
117 argCollection.
append(arguments);
118 segments.push_back(arguments.size());
119 deviceTypes.push_back(
123 for (DeviceType DT : newDeviceTypes) {
124 argCollection.
append(arguments);
125 segments.push_back(arguments.size());
133 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
138 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139 newDeviceTypes, arguments,
140 argCollection, segments);
148 void OpenACCDialect::initialize() {
151 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
154 #define GET_ATTRDEF_LIST
155 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
165 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
166 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
175 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
181 mlir::acc::DeviceType deviceType) {
185 for (
auto attr : *arrayAttr) {
186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
187 if (deviceTypeAttr.getValue() == deviceType)
195 std::optional<mlir::ArrayAttr> deviceTypes) {
200 llvm::interleaveComma(*deviceTypes, p,
206 mlir::acc::DeviceType deviceType) {
207 unsigned segmentIdx = 0;
208 for (
auto attr : segments) {
209 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
210 if (deviceTypeAttr.getValue() == deviceType)
211 return std::make_optional(segmentIdx);
221 mlir::acc::DeviceType deviceType) {
223 return range.take_front(0);
224 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
225 int32_t nbOperandsBefore = 0;
226 for (
unsigned i = 0; i < *pos; ++i)
227 nbOperandsBefore += (*segments)[i];
228 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
230 return range.take_front(0);
237 std::optional<mlir::ArrayAttr> hasWaitDevnum,
238 mlir::acc::DeviceType deviceType) {
241 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
242 if (hasWaitDevnum->getValue()[*pos])
253 std::optional<mlir::ArrayAttr> hasWaitDevnum,
254 mlir::acc::DeviceType deviceType) {
259 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
260 if (hasWaitDevnum && *hasWaitDevnum) {
261 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
262 if (boolAttr.getValue())
263 return range.drop_front(1);
269 template <
typename Op>
271 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
273 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
278 op.hasAsyncOnly(dtype))
280 "asyncOnly attribute cannot appear with asyncOperand");
285 op.hasWaitOnly(dtype))
286 return op.
emitError(
"wait attribute cannot appear with waitOperands");
291 template <
typename Op>
294 return op.
emitError(
"must have var operand");
296 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
302 return op.
emitError(
"var must be mappable or pointer-like (not both)");
305 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
306 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
307 return op.
emitError(
"var must be mappable or pointer-like");
309 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
310 op.getVarType() != op.getVar().getType())
311 return op.
emitError(
"varType must match when var is mappable");
316 template <
typename Op>
318 if (op.getVar().getType() != op.getAccVar().getType())
319 return op.
emitError(
"input and output types must match");
341 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
362 if (failed(parser.
parseType(accVarType)))
372 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
384 mlir::TypeAttr &varTypeAttr) {
385 if (failed(parser.
parseType(varPtrType)))
401 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
403 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
412 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
420 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
421 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
423 if (typeToCheckAgainst != varType) {
434 auto extent = getExtent();
435 auto upperbound = getUpperbound();
436 if (!extent && !upperbound)
437 return emitError(
"expected extent or upperbound.");
447 "data clause associated with private operation must match its intent");
458 return emitError(
"data clause associated with firstprivate operation must "
470 return emitError(
"data clause associated with reduction operation must "
482 return emitError(
"data clause associated with deviceptr operation must "
497 "data clause associated with present operation must match its intent");
510 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
515 "data clause associated with copyin operation must match its intent"
516 " or specify original clause this operation was decomposed from");
// Reports whether this copyin operation carries the `acc_copyin_readonly`
// data clause: the result is true only when getDataClause() compares equal to
// that enumerator.
524 bool acc::CopyinOp::isCopyinReadonly() {
525 return getDataClause() == acc::DataClause::acc_copyin_readonly;
538 "data clause associated with create operation must match its intent"
539 " or specify original clause this operation was decomposed from");
547 bool acc::CreateOp::isCreateZero() {
549 return getDataClause() == acc::DataClause::acc_create_zero ||
558 return emitError(
"data clause associated with no_create operation must "
573 "data clause associated with attach operation must match its intent");
586 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
587 return emitError(
"data clause associated with device_resident operation "
588 "must match its intent");
603 "data clause associated with link operation must match its intent");
621 "data clause associated with copyout operation must match its intent"
622 " or specify original clause this operation was decomposed from");
624 return emitError(
"must have both host and device pointers");
632 bool acc::CopyoutOp::isCopyoutZero() {
648 getDataClause() != acc::DataClause::acc_declare_device_resident &&
651 "data clause associated with delete operation must match its intent"
652 " or specify original clause this operation was decomposed from");
654 return emitError(
"must have device pointer");
666 "data clause associated with detach operation must match its intent"
667 " or specify original clause this operation was decomposed from");
669 return emitError(
"must have device pointer");
681 "data clause associated with host operation must match its intent"
682 " or specify original clause this operation was decomposed from");
684 return emitError(
"must have both host and device pointers");
699 "data clause associated with device operation must match its intent"
700 " or specify original clause this operation was decomposed from");
715 "data clause associated with use_device operation must match its intent"
716 " or specify original clause this operation was decomposed from");
732 "data clause associated with cache operation must match its intent"
733 " or specify original clause this operation was decomposed from");
741 template <
typename StructureOp>
743 unsigned nRegions = 1) {
746 for (
unsigned i = 0; i < nRegions; ++i)
747 regions.push_back(state.addRegion());
749 for (
Region *region : regions)
757 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
764 template <
typename OpTy>
768 LogicalResult matchAndRewrite(OpTy op,
771 Value ifCond = op.getIfCond();
775 IntegerAttr constAttr;
778 if (constAttr.getInt())
779 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
791 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
803 template <
typename OpTy>
804 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
807 LogicalResult matchAndRewrite(OpTy op,
810 Value ifCond = op.getIfCond();
814 IntegerAttr constAttr;
817 if (constAttr.getInt())
818 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
833 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
835 if (optional && region.
empty())
839 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
843 return op->
emitOpError() <<
"expects " << regionName
846 << regionType <<
" type";
849 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
850 if (yieldOp.getOperands().size() != 1 ||
851 yieldOp.getOperands().getTypes()[0] != type)
852 return op->
emitOpError() <<
"expects " << regionName
854 "yield a value of the "
855 << regionType <<
" type";
861 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
863 "privatization",
"init",
getType(),
867 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
877 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
879 "privatization",
"init",
getType(),
883 if (getCopyRegion().empty())
884 return emitOpError() <<
"expects non-empty copy region";
889 return emitOpError() <<
"expects copy region with two arguments of the "
890 "privatization type";
892 if (getDestroyRegion().empty())
896 "privatization",
"destroy",
907 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
913 if (getCombinerRegion().empty())
914 return emitOpError() <<
"expects non-empty combiner region";
916 Block &reductionBlock = getCombinerRegion().
front();
920 return emitOpError() <<
"expects combiner region with the first two "
921 <<
"arguments of the reduction type";
923 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
924 if (yieldOp.getOperands().size() != 1 ||
925 yieldOp.getOperands().getTypes()[0] !=
getType())
926 return emitOpError() <<
"expects combiner region to yield a value "
927 "of the reduction type";
943 if (parser.parseAttribute(attributes.emplace_back()) ||
944 parser.parseArrow() ||
945 parser.parseOperand(operands.emplace_back()) ||
946 parser.parseColonType(types.emplace_back()))
960 std::optional<mlir::ArrayAttr> attributes) {
961 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
962 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
963 << std::get<1>(it).getType();
972 template <
typename Op>
976 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
977 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
978 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
979 operand.getDefiningOp()))
981 "expect data entry/exit operation or acc.getdeviceptr "
986 template <
typename Op>
990 llvm::StringRef symbolName,
bool checkOperandType =
true) {
991 if (!operands.empty()) {
992 if (!attributes || attributes->size() != operands.size())
994 <<
"expected as many " << symbolName <<
" symbol reference as "
995 << operandName <<
" operands";
999 <<
"unexpected " << symbolName <<
" symbol reference";
1004 for (
auto args : llvm::zip(operands, *attributes)) {
1007 if (!set.insert(operand).second)
1009 << operandName <<
" operand appears more than once";
1012 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1013 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1016 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1017 << operandName <<
" declaration";
1019 if (checkOperandType && decl.getType() && decl.getType() != varType)
1020 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1021 <<
") to be the same type as " << operandName
1022 <<
" declaration (" << decl.getType() <<
")";
// Returns the number of data operands of the parallel operation, computed as
// the sum of the sizes of four operand groups: reduction, private,
// firstprivate and data clause operands.
1028 unsigned ParallelOp::getNumDataOperands() {
1029 return getReductionOperands().size() + getPrivateOperands().size() +
1030 getFirstprivateOperands().size() + getDataClauseOperands().size();
// Returns the operand located at index
// `getWaitOperands().size() + numOptional + i` in the operation's operand
// list.
// NOTE(review): the declaration and initial value of `numOptional` are not
// visible in this excerpt — confirm what it is seeded with before relying on
// the offset described here.
1033 Value ParallelOp::getDataOperand(
unsigned i) {
// Accumulate the sizes of the num_gangs, num_workers and vector_length
// operand groups into the offset.
1035 numOptional += getNumGangs().size();
1036 numOptional += getNumWorkers().size();
1037 numOptional += getVectorLength().size();
// The if and self conditions each add one to the offset when present.
1038 numOptional += getIfCond() ? 1 : 0;
1039 numOptional += getSelfCond() ? 1 : 0;
1040 return getOperand(getWaitOperands().size() + numOptional + i);
1043 template <
typename Op>
1045 ArrayAttr deviceTypes,
1046 llvm::StringRef keyword) {
1047 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1048 return op.
emitOpError() << keyword <<
" operands count must match "
1049 << keyword <<
" device_type count";
1053 template <
typename Op>
1056 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1057 std::size_t numOperandsInSegments = 0;
1058 std::size_t nbOfSegments = 0;
1061 for (
auto segCount : segments.
asArrayRef()) {
1062 if (maxInSegment != 0 && segCount > maxInSegment)
1063 return op.
emitOpError() << keyword <<
" expects a maximum of "
1064 << maxInSegment <<
" values per segment";
1065 numOperandsInSegments += segCount;
1070 if ((numOperandsInSegments != operands.size()) ||
1071 (!deviceTypes && !operands.empty()))
1073 << keyword <<
" operand count does not match count in segments";
1074 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1076 << keyword <<
" segment count does not match device_type count";
1081 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1082 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1083 "privatizations",
false)))
1085 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1086 *
this, getFirstprivatizations(), getFirstprivateOperands(),
1087 "firstprivate",
"firstprivatizations",
false)))
1089 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1090 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1091 "reductions",
false)))
1095 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1096 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1100 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1101 getWaitOperandsDeviceTypeAttr(),
"wait")))
1105 getNumWorkersDeviceTypeAttr(),
1110 getVectorLengthDeviceTypeAttr(),
1115 getAsyncOperandsDeviceTypeAttr(),
1119 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1122 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1128 mlir::acc::DeviceType deviceType) {
1131 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1136 bool acc::ParallelOp::hasAsyncOnly() {
1140 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1148 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1153 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1158 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1163 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1168 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1170 getVectorLength(), deviceType);
1178 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1180 getNumGangsSegments(), deviceType);
1183 bool acc::ParallelOp::hasWaitOnly() {
1187 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1196 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1198 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1199 getHasWaitDevnum(), deviceType);
1206 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1208 getWaitOperandsSegments(), getHasWaitDevnum(),
1224 odsBuilder, odsState, asyncOperands,
nullptr,
1225 nullptr, waitOperands,
nullptr,
1227 nullptr, numGangs,
nullptr,
1228 nullptr, numWorkers,
1229 nullptr, vectorLength,
1230 nullptr, ifCond, selfCond,
1231 nullptr, reductionOperands,
nullptr,
1232 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1233 nullptr, dataClauseOperands,
1237 void acc::ParallelOp::addNumWorkersOperand(
1240 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1241 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1242 getNumWorkersMutable()));
1244 void acc::ParallelOp::addVectorLengthOperand(
1247 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1248 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1249 getVectorLengthMutable()));
1252 void acc::ParallelOp::addAsyncOnly(
1254 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1255 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1258 void acc::ParallelOp::addAsyncOperand(
1261 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1262 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1263 getAsyncOperandsMutable()));
1266 void acc::ParallelOp::addNumGangsOperands(
1270 if (getNumGangsSegments())
1271 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1273 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1274 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1275 getNumGangsMutable(), segments));
1277 setNumGangsSegments(segments);
1279 void acc::ParallelOp::addWaitOnly(
1281 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1282 effectiveDeviceTypes));
1284 void acc::ParallelOp::addWaitOperands(
1289 if (getWaitOperandsSegments())
1290 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1292 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1293 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1294 getWaitOperandsMutable(), segments));
1295 setWaitOperandsSegments(segments);
1298 if (getHasWaitDevnumAttr())
1299 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1302 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1319 int32_t crtOperandsSize = operands.size();
1322 if (parser.parseOperand(operands.emplace_back()) ||
1323 parser.parseColonType(types.emplace_back()))
1328 seg.push_back(operands.size() - crtOperandsSize);
1352 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1354 p <<
" [" << attr <<
"]";
1359 std::optional<mlir::ArrayAttr> deviceTypes,
1360 std::optional<mlir::DenseI32ArrayAttr> segments) {
1362 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1364 llvm::interleaveComma(
1365 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1366 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1386 int32_t crtOperandsSize = operands.size();
1390 if (parser.parseOperand(operands.emplace_back()) ||
1391 parser.parseColonType(types.emplace_back()))
1397 seg.push_back(operands.size() - crtOperandsSize);
1423 std::optional<mlir::DenseI32ArrayAttr> segments) {
1425 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1427 llvm::interleaveComma(
1428 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1429 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1442 mlir::ArrayAttr &keywordOnly) {
1446 bool needCommaBeforeOperands =
false;
1459 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1466 needCommaBeforeOperands =
true;
1469 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1476 int32_t crtOperandsSize = operands.size();
1488 if (parser.parseOperand(operands.emplace_back()) ||
1489 parser.parseColonType(types.emplace_back()))
1495 seg.push_back(operands.size() - crtOperandsSize);
1524 if (attrs->size() != 1)
1526 if (
auto deviceTypeAttr =
1527 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1534 std::optional<mlir::ArrayAttr> deviceTypes,
1535 std::optional<mlir::DenseI32ArrayAttr> segments,
1536 std::optional<mlir::ArrayAttr> hasDevNum,
1537 std::optional<mlir::ArrayAttr> keywordOnly) {
1550 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1552 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1553 if (boolAttr && boolAttr.getValue())
1555 llvm::interleaveComma(
1556 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1557 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1574 if (parser.parseOperand(operands.emplace_back()) ||
1575 parser.parseColonType(types.emplace_back()))
1577 if (succeeded(parser.parseOptionalLSquare())) {
1578 if (parser.parseAttribute(attributes.emplace_back()) ||
1579 parser.parseRSquare())
1582 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1583 parser.getContext(), mlir::acc::DeviceType::None));
1597 std::optional<mlir::ArrayAttr> deviceTypes) {
1600 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1601 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1610 mlir::ArrayAttr &keywordOnlyDeviceType) {
1613 bool needCommaBeforeOperands =
false;
1619 keywordOnlyDeviceType =
1628 if (parser.parseAttribute(
1629 keywordOnlyDeviceTypeAttributes.emplace_back()))
1636 needCommaBeforeOperands =
true;
1639 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1644 if (parser.parseOperand(operands.emplace_back()) ||
1645 parser.parseColonType(types.emplace_back()))
1647 if (succeeded(parser.parseOptionalLSquare())) {
1648 if (parser.parseAttribute(attributes.emplace_back()) ||
1649 parser.parseRSquare())
1652 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1653 parser.getContext(), mlir::acc::DeviceType::None));
1659 if (failed(parser.parseRParen()))
1671 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1673 if (operands.begin() == operands.end() &&
1689 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1690 mlir::Type &operandType, mlir::UnitAttr &attr) {
1703 if (failed(parser.
parseType(operandType)))
1713 std::optional<mlir::Value> operand,
1715 mlir::UnitAttr attr) {
1737 if (parser.parseOperand(operands.emplace_back()))
1745 if (parser.parseType(types.emplace_back()))
1760 mlir::UnitAttr attr) {
1765 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
1767 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
1773 mlir::acc::CombinedConstructsTypeAttr &attr) {
1776 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1779 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1782 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1785 "expected compute construct name");
1793 mlir::acc::CombinedConstructsTypeAttr attr) {
1795 switch (attr.getValue()) {
1796 case mlir::acc::CombinedConstructsType::KernelsLoop:
1799 case mlir::acc::CombinedConstructsType::ParallelLoop:
1802 case mlir::acc::CombinedConstructsType::SerialLoop:
// Returns the number of data operands of the serial operation, computed as
// the sum of the sizes of four operand groups: reduction, private,
// firstprivate and data clause operands.
1813 unsigned SerialOp::getNumDataOperands() {
1814 return getReductionOperands().size() + getPrivateOperands().size() +
1815 getFirstprivateOperands().size() + getDataClauseOperands().size();
// Returns the operand located at index
// `getWaitOperands().size() + numOptional + i` in the operation's operand
// list.
// NOTE(review): the declaration and initial value of `numOptional` are not
// visible in this excerpt — confirm what it is seeded with before relying on
// the offset described here.
1818 Value SerialOp::getDataOperand(
unsigned i) {
// The if and self conditions each add one to the offset when present.
1820 numOptional += getIfCond() ? 1 : 0;
1821 numOptional += getSelfCond() ? 1 : 0;
1822 return getOperand(getWaitOperands().size() + numOptional + i);
1825 bool acc::SerialOp::hasAsyncOnly() {
1829 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1837 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1842 bool acc::SerialOp::hasWaitOnly() {
1846 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1855 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1857 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1858 getHasWaitDevnum(), deviceType);
1865 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1867 getWaitOperandsSegments(), getHasWaitDevnum(),
1872 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1873 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1874 "privatizations",
false)))
1876 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1877 *
this, getFirstprivatizations(), getFirstprivateOperands(),
1878 "firstprivate",
"firstprivatizations",
false)))
1880 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1881 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1882 "reductions",
false)))
1886 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1887 getWaitOperandsDeviceTypeAttr(),
"wait")))
1891 getAsyncOperandsDeviceTypeAttr(),
1895 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1898 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1901 void acc::SerialOp::addAsyncOnly(
1903 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1904 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1907 void acc::SerialOp::addAsyncOperand(
1910 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912 getAsyncOperandsMutable()));
1915 void acc::SerialOp::addWaitOnly(
1917 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1918 effectiveDeviceTypes));
1920 void acc::SerialOp::addWaitOperands(
1925 if (getWaitOperandsSegments())
1926 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1928 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1929 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1930 getWaitOperandsMutable(), segments));
1931 setWaitOperandsSegments(segments);
1934 if (getHasWaitDevnumAttr())
1935 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1938 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
// Returns the number of data operands of the kernels operation, which is the
// size of the data clause operand group alone.
1947 unsigned KernelsOp::getNumDataOperands() {
1948 return getDataClauseOperands().size();
// Returns the operand located at index `numOptional + i` in the operation's
// operand list.
// NOTE(review): the declaration and initial value of `numOptional` are not
// visible in this excerpt — confirm what it is seeded with before relying on
// the offset described here.
1951 Value KernelsOp::getDataOperand(
unsigned i) {
// Accumulate the sizes of the wait, num_gangs, num_workers and vector_length
// operand groups into the offset.
1953 numOptional += getWaitOperands().size();
1954 numOptional += getNumGangs().size();
1955 numOptional += getNumWorkers().size();
1956 numOptional += getVectorLength().size();
// The if and self conditions each add one to the offset when present.
1957 numOptional += getIfCond() ? 1 : 0;
1958 numOptional += getSelfCond() ? 1 : 0;
1959 return getOperand(numOptional + i);
1962 bool acc::KernelsOp::hasAsyncOnly() {
1966 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1974 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1979 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1984 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1989 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1994 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1996 getVectorLength(), deviceType);
2004 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2006 getNumGangsSegments(), deviceType);
2009 bool acc::KernelsOp::hasWaitOnly() {
2013 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2022 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2024 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2025 getHasWaitDevnum(), deviceType);
2032 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2034 getWaitOperandsSegments(), getHasWaitDevnum(),
2040 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2041 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2045 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2046 getWaitOperandsDeviceTypeAttr(),
"wait")))
2050 getNumWorkersDeviceTypeAttr(),
2055 getVectorLengthDeviceTypeAttr(),
2060 getAsyncOperandsDeviceTypeAttr(),
2064 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
2067 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
2070 void acc::KernelsOp::addNumWorkersOperand(
2073 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2074 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2075 getNumWorkersMutable()));
2078 void acc::KernelsOp::addVectorLengthOperand(
2081 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2082 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2083 getVectorLengthMutable()));
2085 void acc::KernelsOp::addAsyncOnly(
2087 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2088 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2091 void acc::KernelsOp::addAsyncOperand(
2094 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2095 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2096 getAsyncOperandsMutable()));
2099 void acc::KernelsOp::addNumGangsOperands(
2103 if (getNumGangsSegmentsAttr())
2104 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2106 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2107 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2108 getNumGangsMutable(), segments));
2110 setNumGangsSegments(segments);
2113 void acc::KernelsOp::addWaitOnly(
2115 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2116 effectiveDeviceTypes));
2118 void acc::KernelsOp::addWaitOperands(
2123 if (getWaitOperandsSegments())
2124 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2126 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2127 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2128 getWaitOperandsMutable(), segments));
2129 setWaitOperandsSegments(segments);
2132 if (getHasWaitDevnumAttr())
2133 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2136 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2146 if (getDataClauseOperands().empty())
2147 return emitError(
"at least one operand must appear on the host_data "
2150 for (
mlir::Value operand : getDataClauseOperands())
2151 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2152 return emitError(
"expect data entry operation as defining op");
2158 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2170 bool &needCommaBetweenValues,
bool &newValue) {
2177 attributes.push_back(gangArgType);
2178 needCommaBetweenValues =
true;
2189 mlir::ArrayAttr &gangOnlyDeviceType) {
2194 bool needCommaBetweenValues =
false;
2195 bool needCommaBeforeOperands =
false;
2201 gangOnlyDeviceType =
2210 if (parser.parseAttribute(
2211 gangOnlyDeviceTypeAttributes.emplace_back()))
2218 needCommaBeforeOperands =
true;
2222 mlir::acc::GangArgType::Num);
2224 mlir::acc::GangArgType::Dim);
2226 parser.
getContext(), mlir::acc::GangArgType::Static);
2229 if (needCommaBeforeOperands) {
2230 needCommaBeforeOperands =
false;
2237 int32_t crtOperandsSize = gangOperands.size();
2239 bool newValue =
false;
2240 bool needValue =
false;
2241 if (needCommaBetweenValues) {
2249 gangOperands, gangOperandsType,
2250 gangArgTypeAttributes, argNum,
2251 needCommaBetweenValues, newValue)))
2254 gangOperands, gangOperandsType,
2255 gangArgTypeAttributes, argDim,
2256 needCommaBetweenValues, newValue)))
2258 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2259 gangOperands, gangOperandsType,
2260 gangArgTypeAttributes, argStatic,
2261 needCommaBetweenValues, newValue)))
2264 if (!newValue && needValue) {
2266 "new value expected after comma");
2274 if (gangOperands.empty())
2277 "expect at least one of num, dim or static values");
2283 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2291 seg.push_back(gangOperands.size() - crtOperandsSize);
2299 gangArgTypeAttributes.end());
2304 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2313 std::optional<mlir::ArrayAttr> gangArgTypes,
2314 std::optional<mlir::ArrayAttr> deviceTypes,
2315 std::optional<mlir::DenseI32ArrayAttr> segments,
2316 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2318 if (operands.begin() == operands.end() &&
2333 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2335 llvm::interleaveComma(
2336 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2337 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2338 (*gangArgTypes)[opIdx]);
2339 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2340 p << LoopOp::getGangNumKeyword();
2341 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2342 p << LoopOp::getGangDimKeyword();
2343 else if (gangArgTypeAttr.getValue() ==
2344 mlir::acc::GangArgType::Static)
2345 p << LoopOp::getGangStaticKeyword();
2346 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2357 std::optional<mlir::ArrayAttr> segments,
2358 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2361 for (
auto attr : *segments) {
2362 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2363 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2371 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2374 for (
auto attr : deviceTypes) {
2375 auto deviceTypeAttr =
2376 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2377 if (!deviceTypeAttr)
2379 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2386 if (getUpperbound().size() != getStep().size())
2387 return emitError() <<
"number of upperbounds expected to be the same as "
2390 if (getUpperbound().size() != getLowerbound().size())
2391 return emitError() <<
"number of upperbounds expected to be the same as "
2392 "number of lowerbounds";
2394 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2395 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2396 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2397 <<
" as upperbound size";
2400 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2401 return emitOpError() <<
"collapse device_type attr must be define when"
2402 <<
" collapse attr is present";
2404 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2405 getCollapseAttr().getValue().size() !=
2406 getCollapseDeviceTypeAttr().getValue().size())
2407 return emitOpError() <<
"collapse attribute count must match collapse"
2408 <<
" device_type count";
2410 return emitOpError()
2411 <<
"duplicate device_type found in collapseDeviceType attribute";
2414 if (!getGangOperands().empty()) {
2415 if (!getGangOperandsArgType())
2416 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2417 <<
" when gang operands are present";
2419 if (getGangOperands().size() !=
2420 getGangOperandsArgTypeAttr().getValue().size())
2421 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2422 <<
" gangOperands count";
2425 return emitOpError() <<
"duplicate device_type found in gang attribute";
2428 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2429 getGangOperandsDeviceTypeAttr(),
"gang")))
2434 return emitOpError() <<
"duplicate device_type found in worker attribute";
2436 return emitOpError() <<
"duplicate device_type found in "
2437 "workerNumOperandsDeviceType attribute";
2439 getWorkerNumOperandsDeviceTypeAttr(),
2445 return emitOpError() <<
"duplicate device_type found in vector attribute";
2447 return emitOpError() <<
"duplicate device_type found in "
2448 "vectorOperandsDeviceType attribute";
2450 getVectorOperandsDeviceTypeAttr(),
2455 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2456 getTileOperandsDeviceTypeAttr(),
"tile")))
2460 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2464 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
2465 <<
"\", " << getIndependentAttrName() <<
", "
2467 <<
" can be present at the same time";
2472 for (
auto attr : getSeqAttr()) {
2473 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2474 if (hasVector(deviceTypeAttr.getValue()) ||
2475 getVectorValue(deviceTypeAttr.getValue()) ||
2476 hasWorker(deviceTypeAttr.getValue()) ||
2477 getWorkerValue(deviceTypeAttr.getValue()) ||
2478 hasGang(deviceTypeAttr.getValue()) ||
2479 getGangValue(mlir::acc::GangArgType::Num,
2480 deviceTypeAttr.getValue()) ||
2481 getGangValue(mlir::acc::GangArgType::Dim,
2482 deviceTypeAttr.getValue()) ||
2483 getGangValue(mlir::acc::GangArgType::Static,
2484 deviceTypeAttr.getValue()))
2486 <<
"gang, worker or vector cannot appear with the seq attr";
2490 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2491 *
this, getPrivatizations(), getPrivateOperands(),
"private",
2492 "privatizations",
false)))
2495 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2496 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2497 "reductions",
false)))
2500 if (getCombined().has_value() &&
2501 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2502 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2503 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2504 return emitError(
"unexpected combined constructs attribute");
2508 if (getRegion().empty())
2509 return emitError(
"expected non-empty body.");
2512 if (isContainerLike()) {
2515 uint64_t collapseCount = getCollapseValue().value_or(1);
2516 if (getCollapseAttr()) {
2517 for (
auto collapseEntry : getCollapseAttr()) {
2518 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2519 if (intAttr.getValue().getZExtValue() > collapseCount)
2520 collapseCount = intAttr.getValue().getZExtValue();
2528 bool foundSibling =
false;
2530 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2532 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2534 foundSibling =
true;
2539 expectedParent = op;
2542 if (collapseCount == 0)
2548 return emitError(
"found sibling loops inside container-like acc.loop");
2549 if (collapseCount != 0)
2550 return emitError(
"failed to find enough loop-like operations inside "
2551 "container-like acc.loop");
/// Returns the number of data operands on this LoopOp, computed as the count
/// of reduction operands plus the count of private operands.
2557 unsigned LoopOp::getNumDataOperands() {
2558 return getReductionOperands().size() + getPrivateOperands().size();
/// Returns the i-th data operand of this LoopOp.
///
/// The index is shifted past the operand groups summed below (lower bounds,
/// upper bounds, steps, gang, vector, worker-num, tile and cache operands)
/// and the result is read from the op's flat operand list.
/// No range check on `i` is performed here.
/// NOTE(review): this relies on the data operands directly following those
/// groups in the op's operand order -- confirm against the op definition.
2561 Value LoopOp::getDataOperand(
unsigned i) {
// Loop bounds and steps are counted first.
2562 unsigned numOptional =
2563 getLowerbound().size() + getUpperbound().size() + getStep().size();
// Then every parallelism / tile / cache operand group.
2564 numOptional += getGangOperands().size();
2565 numOptional += getVectorOperands().size();
2566 numOptional += getWorkerNumOperands().size();
2567 numOptional += getTileOperands().size();
2568 numOptional += getCacheOperands().size();
2569 return getOperand(numOptional + i);
2574 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2578 bool LoopOp::hasIndependent() {
2582 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2588 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2596 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2598 getVectorOperands(), deviceType);
2603 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2611 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2613 getWorkerNumOperands(), deviceType);
2618 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2627 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2629 getTileOperandsSegments(), deviceType);
2632 std::optional<int64_t> LoopOp::getCollapseValue() {
2636 std::optional<int64_t>
2637 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2638 if (!getCollapseAttr())
2639 return std::nullopt;
2640 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2642 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2643 return intAttr.getValue().getZExtValue();
2645 return std::nullopt;
2648 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2652 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2653 mlir::acc::DeviceType deviceType) {
2654 if (getGangOperands().empty())
2656 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2657 int32_t nbOperandsBefore = 0;
2658 for (
unsigned i = 0; i < *pos; ++i)
2659 nbOperandsBefore += (*getGangOperandsSegments())[i];
2662 .drop_front(nbOperandsBefore)
2663 .take_front((*getGangOperandsSegments())[*pos]);
2665 int32_t argTypeIdx = nbOperandsBefore;
2666 for (
auto value : values) {
2667 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2668 (*getGangOperandsArgType())[argTypeIdx]);
2669 if (gangArgTypeAttr.getValue() == gangArgType)
2679 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2684 return {&getRegion()};
2728 if (!regionArgs.empty()) {
2729 p << acc::LoopOp::getControlKeyword() <<
"(";
2730 llvm::interleaveComma(regionArgs, p,
2732 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2733 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2734 <<
" : " << stepType <<
") ";
2741 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2742 effectiveDeviceTypes));
2745 void acc::LoopOp::addIndependent(
2747 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2748 context, getIndependentAttr(), effectiveDeviceTypes));
2753 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2754 effectiveDeviceTypes));
2757 void acc::LoopOp::setCollapseForDeviceTypes(
2759 llvm::APInt value) {
2763 assert((getCollapseAttr() ==
nullptr) ==
2764 (getCollapseDeviceTypeAttr() ==
nullptr));
2765 assert(value.getBitWidth() == 64);
2767 if (getCollapseAttr()) {
2768 for (
const auto &existing :
2769 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2770 newValues.push_back(std::get<0>(existing));
2771 newDeviceTypes.push_back(std::get<1>(existing));
2775 if (effectiveDeviceTypes.empty()) {
2778 newValues.push_back(
2780 newDeviceTypes.push_back(
2783 for (DeviceType DT : effectiveDeviceTypes) {
2784 newValues.push_back(
2791 setCollapseDeviceTypeAttr(
ArrayAttr::get(context, newDeviceTypes));
2794 void acc::LoopOp::setTileForDeviceTypes(
2798 if (getTileOperandsSegments())
2799 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2801 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2802 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2803 getTileOperandsMutable(), segments));
2805 setTileOperandsSegments(segments);
2808 void acc::LoopOp::addVectorOperand(
2811 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2812 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2813 newValue, getVectorOperandsMutable()));
2816 void acc::LoopOp::addEmptyVector(
2818 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
2819 effectiveDeviceTypes));
2822 void acc::LoopOp::addWorkerNumOperand(
2825 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2826 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2827 newValue, getWorkerNumOperandsMutable()));
2830 void acc::LoopOp::addEmptyWorker(
2832 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
2833 effectiveDeviceTypes));
2836 void acc::LoopOp::addEmptyGang(
2838 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
2839 effectiveDeviceTypes));
2842 void acc::LoopOp::addGangOperands(
2847 getGangOperandsSegments())
2848 llvm::copy(*existingSegments, std::back_inserter(segments));
2850 unsigned beforeCount = segments.size();
2852 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2853 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2854 getGangOperandsMutable(), segments));
2856 setGangOperandsSegments(segments);
2863 unsigned numAdded = segments.size() - beforeCount;
2867 if (getGangOperandsArgTypeAttr())
2868 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
2870 for (
auto i : llvm::index_range(0u, numAdded)) {
2871 llvm::transform(argTypes, std::back_inserter(gangTypes),
2872 [=](mlir::acc::GangArgType gangTy) {
2890 if (getOperands().empty() && !getDefaultAttr())
2891 return emitError(
"at least one operand or the default attribute "
2892 "must appear on the data operation");
2894 for (
mlir::Value operand : getDataClauseOperands())
2895 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2896 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2897 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2898 operand.getDefiningOp()))
2899 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2902 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
/// Returns the number of data clause operands attached to this DataOp.
2908 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2910 Value DataOp::getDataOperand(
unsigned i) {
2911 unsigned numOptional = getIfCond() ? 1 : 0;
2913 numOptional += getWaitOperands().size();
2914 return getOperand(numOptional + i);
2917 bool acc::DataOp::hasAsyncOnly() {
2921 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2929 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2936 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2945 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2947 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2948 getHasWaitDevnum(), deviceType);
2955 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2957 getWaitOperandsSegments(), getHasWaitDevnum(),
2961 void acc::DataOp::addAsyncOnly(
2963 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2964 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2967 void acc::DataOp::addAsyncOperand(
2970 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2971 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2972 getAsyncOperandsMutable()));
2975 void acc::DataOp::addWaitOnly(
MLIRContext *context,
2977 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2978 effectiveDeviceTypes));
2981 void acc::DataOp::addWaitOperands(
2986 if (getWaitOperandsSegments())
2987 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2989 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2990 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2991 getWaitOperandsMutable(), segments));
2992 setWaitOperandsSegments(segments);
2995 if (getHasWaitDevnumAttr())
2996 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2999 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3012 if (getDataClauseOperands().empty())
3013 return emitError(
"at least one operand must be present in dataOperands on "
3014 "the exit data operation");
3018 if (getAsyncOperand() && getAsync())
3019 return emitError(
"async attribute cannot appear with asyncOperand");
3023 if (!getWaitOperands().empty() && getWait())
3024 return emitError(
"wait attribute cannot appear with waitOperands");
3026 if (getWaitDevnum() && getWaitOperands().empty())
3027 return emitError(
"wait_devnum cannot appear without waitOperands");
/// Returns the number of data clause operands attached to this ExitDataOp.
3032 unsigned ExitDataOp::getNumDataOperands() {
3033 return getDataClauseOperands().size();
/// Returns the i-th data clause operand of this ExitDataOp.
///
/// The index is shifted past the optional if-condition, async operand and
/// wait-devnum operand (one slot each when present) and past all wait
/// operands. No range check on `i` is performed here.
/// NOTE(review): this relies on those operands preceding the data clause
/// operands in the op's operand order -- confirm against the op definition.
3036 Value ExitDataOp::getDataOperand(
unsigned i) {
// Each optional single operand contributes one slot only when it is set.
3037 unsigned numOptional = getIfCond() ? 1 : 0;
3038 numOptional += getAsyncOperand() ? 1 : 0;
3039 numOptional += getWaitDevnum() ? 1 : 0;
3040 return getOperand(getWaitOperands().size() + numOptional + i);
3045 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3056 if (getDataClauseOperands().empty())
3057 return emitError(
"at least one operand must be present in dataOperands on "
3058 "the enter data operation");
3062 if (getAsyncOperand() && getAsync())
3063 return emitError(
"async attribute cannot appear with asyncOperand");
3067 if (!getWaitOperands().empty() && getWait())
3068 return emitError(
"wait attribute cannot appear with waitOperands");
3070 if (getWaitDevnum() && getWaitOperands().empty())
3071 return emitError(
"wait_devnum cannot appear without waitOperands");
3073 for (
mlir::Value operand : getDataClauseOperands())
3074 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3075 operand.getDefiningOp()))
3076 return emitError(
"expect data entry operation as defining op");
/// Returns the number of data clause operands attached to this EnterDataOp.
3081 unsigned EnterDataOp::getNumDataOperands() {
3082 return getDataClauseOperands().size();
/// Returns the i-th data clause operand of this EnterDataOp.
///
/// The index is shifted past the optional if-condition, async operand and
/// wait-devnum operand (one slot each when present) and past all wait
/// operands. No range check on `i` is performed here.
/// NOTE(review): this relies on those operands preceding the data clause
/// operands in the op's operand order -- confirm against the op definition.
3085 Value EnterDataOp::getDataOperand(
unsigned i) {
// Each optional single operand contributes one slot only when it is set.
3086 unsigned numOptional = getIfCond() ? 1 : 0;
3087 numOptional += getAsyncOperand() ? 1 : 0;
3088 numOptional += getWaitDevnum() ? 1 : 0;
3089 return getOperand(getWaitOperands().size() + numOptional + i);
3094 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3113 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3120 if (
Value writeVal = op.getWriteOpVal()) {
/// Region verifier for AtomicUpdateOp; forwards directly to
/// verifyRegionsCommon() and returns its result.
3130 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3136 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3137 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3139 return dyn_cast<AtomicReadOp>(getSecondOp());
3142 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3143 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3145 return dyn_cast<AtomicWriteOp>(getSecondOp());
3148 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3149 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3151 return dyn_cast<AtomicUpdateOp>(getSecondOp());
/// Region verifier for AtomicCaptureOp; forwards directly to
/// verifyRegionsCommon() and returns its result.
3154 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3160 template <
typename Op>
3161 static LogicalResult
3163 bool requireAtLeastOneOperand =
true) {
3164 if (operands.empty() && requireAtLeastOneOperand)
3167 "at least one operand must appear on the declare operation");
3170 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3171 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3172 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3173 operand.getDefiningOp()))
3175 "expect valid declare data entry operation or acc.getdeviceptr "
3179 assert(var &&
"declare operands can only be data entry operations which "
3182 std::optional<mlir::acc::DataClause> dataClauseOptional{
3184 assert(dataClauseOptional.has_value() &&
3185 "declare operands can only be data entry operations which must have "
3187 (void)dataClauseOptional;
3221 acc::DeviceType dtype) {
3222 unsigned parallelism = 0;
3223 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3224 parallelism += op.hasWorker(dtype) ? 1 : 0;
3225 parallelism += op.hasVector(dtype) ? 1 : 0;
3226 parallelism += op.hasSeq(dtype) ? 1 : 0;
3231 unsigned baseParallelism =
3234 if (baseParallelism > 1)
3235 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3236 "be present at the same time";
3238 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3240 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
3245 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3246 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3247 "be present at the same time";
3254 mlir::ArrayAttr &deviceTypes) {
3259 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3261 if (failed(parser.parseOptionalLSquare())) {
3262 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3263 parser.getContext(), mlir::acc::DeviceType::None));
3265 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3266 parser.parseRSquare())
3274 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3280 std::optional<mlir::ArrayAttr> bindName,
3281 std::optional<mlir::ArrayAttr> deviceTypes) {
3282 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3283 [&](
const auto &pair) {
3284 p << std::get<0>(pair);
3290 mlir::ArrayAttr &gang,
3291 mlir::ArrayAttr &gangDim,
3292 mlir::ArrayAttr &gangDimDeviceTypes) {
3295 gangDimDeviceTypeAttrs;
3296 bool needCommaBeforeOperands =
false;
3309 if (parser.parseAttribute(gangAttrs.emplace_back()))
3316 needCommaBeforeOperands =
true;
3319 if (needCommaBeforeOperands && failed(parser.
parseComma()))
3323 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3324 parser.parseColon() ||
3325 parser.parseAttribute(gangDimAttrs.emplace_back()))
3327 if (succeeded(parser.parseOptionalLSquare())) {
3328 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3329 parser.parseRSquare())
3332 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3333 parser.getContext(), mlir::acc::DeviceType::None));
3339 if (failed(parser.parseRParen()))
3344 gangDimDeviceTypes =
3351 std::optional<mlir::ArrayAttr> gang,
3352 std::optional<mlir::ArrayAttr> gangDim,
3353 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3356 gang->size() == 1) {
3357 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3370 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3371 [&](
const auto &pair) {
3372 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
3373 p << std::get<0>(pair);
3381 mlir::ArrayAttr &deviceTypes) {
3394 if (parser.parseAttribute(attributes.emplace_back()))
3408 std::optional<mlir::ArrayAttr> deviceTypes) {
3411 auto deviceTypeAttr =
3412 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3422 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3430 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3436 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3442 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3446 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3450 std::optional<llvm::StringRef>
3451 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3453 return std::nullopt;
3454 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
3455 auto attr = (*getBindName())[*pos];
3456 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3457 return stringAttr.getValue();
3459 return std::nullopt;
3464 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3468 std::optional<int64_t> RoutineOp::getGangDimValue() {
3472 std::optional<int64_t>
3473 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3475 return std::nullopt;
3476 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
3477 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3478 return intAttr.getInt();
3480 return std::nullopt;
3491 return emitOpError(
"cannot be nested in a compute operation");
3495 void acc::InitOp::addDeviceType(
MLIRContext *context,
3496 mlir::acc::DeviceType deviceType) {
3498 if (getDeviceTypesAttr())
3499 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3513 return emitOpError(
"cannot be nested in a compute operation");
3517 void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
3518 mlir::acc::DeviceType deviceType) {
3520 if (getDeviceTypesAttr())
3521 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3535 return emitOpError(
"cannot be nested in a compute operation");
3536 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3537 return emitOpError(
"at least one default_async, device_num, or device_type "
3538 "operand must appear");
3548 if (getDataClauseOperands().empty())
3549 return emitError(
"at least one value must be present in dataOperands");
3552 getAsyncOperandsDeviceTypeAttr(),
3557 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3558 getWaitOperandsDeviceTypeAttr(),
"wait")))
3561 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
3564 for (
mlir::Value operand : getDataClauseOperands())
3565 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3566 operand.getDefiningOp()))
3567 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
/// Returns the number of data clause operands attached to this UpdateOp.
3573 unsigned UpdateOp::getNumDataOperands() {
3574 return getDataClauseOperands().size();
3577 Value UpdateOp::getDataOperand(
unsigned i) {
3579 numOptional += getIfCond() ? 1 : 0;
3580 return getOperand(getWaitOperands().size() + numOptional + i);
3585 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
3588 bool UpdateOp::hasAsyncOnly() {
3592 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3600 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3610 bool UpdateOp::hasWaitOnly() {
3614 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3623 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3625 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3626 getHasWaitDevnum(), deviceType);
3633 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3635 getWaitOperandsSegments(), getHasWaitDevnum(),
3646 if (getAsyncOperand() && getAsync())
3647 return emitError(
"async attribute cannot appear with asyncOperand");
3649 if (getWaitDevnum() && getWaitOperands().empty())
3650 return emitError(
"wait_devnum cannot appear without waitOperands");
3655 #define GET_OP_CLASSES
3656 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3658 #define GET_ATTRDEF_CLASSES
3659 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3661 #define GET_TYPEDEF_CLASSES
3662 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3673 .Case<ACC_DATA_ENTRY_OPS>(
3674 [&](
auto entry) {
return entry.getVarPtr(); })
3675 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3676 [&](
auto exit) {
return exit.getVarPtr(); })
3694 [&](
auto entry) {
return entry.getVarType(); })
3695 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3696 [&](
auto exit) {
return exit.getVarType(); })
3706 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3707 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
3717 [&](
auto dataClause) {
return dataClause.getAccVar(); })
3726 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
3736 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3738 dataClause.getBounds().begin(), dataClause.getBounds().end());
3750 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3752 dataClause.getAsyncOperands().begin(),
3753 dataClause.getAsyncOperands().end());
3764 return dataClause.getAsyncOperandsDeviceTypeAttr();
3772 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
3779 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
3786 std::optional<mlir::acc::DataClause>
3791 .Case<ACC_DATA_ENTRY_OPS>(
3792 [&](
auto entry) {
return entry.getDataClause(); })
3800 [&](
auto entry) {
return entry.getImplicit(); })
3809 [&](
auto entry) {
return entry.getDataClauseOperands(); })
3811 return dataOperands;
3819 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
3821 return dataOperands;
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.