21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36 static bool isScalarLikeType(
Type type) {
40 struct MemRefPointerLikeModel
41 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
44 return cast<MemRefType>(pointer).getElementType();
46 mlir::acc::VariableTypeCategory
49 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
50 return mappableTy.getTypeCategory(varPtr);
52 auto memrefTy = cast<MemRefType>(pointer);
53 if (!memrefTy.hasRank()) {
56 return mlir::acc::VariableTypeCategory::uncategorized;
59 if (memrefTy.getRank() == 0) {
60 if (isScalarLikeType(memrefTy.getElementType())) {
61 return mlir::acc::VariableTypeCategory::scalar;
65 return mlir::acc::VariableTypeCategory::uncategorized;
69 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
70 return mlir::acc::VariableTypeCategory::array;
74 struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
84 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
85 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
88 if (existingDeviceTypes)
89 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
91 if (newDeviceTypes.empty())
92 deviceTypes.push_back(
95 for (DeviceType DT : newDeviceTypes)
107 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
108 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
113 if (existingDeviceTypes)
114 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
116 if (newDeviceTypes.empty()) {
117 argCollection.
append(arguments);
118 segments.push_back(arguments.size());
119 deviceTypes.push_back(
123 for (DeviceType DT : newDeviceTypes) {
124 argCollection.
append(arguments);
125 segments.push_back(arguments.size());
133 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
134 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
138 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
139 newDeviceTypes, arguments,
140 argCollection, segments);
148 void OpenACCDialect::initialize() {
151 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
154 #define GET_ATTRDEF_LIST
155 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
158 #define GET_TYPEDEF_LIST
159 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
165 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
166 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
175 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
181 mlir::acc::DeviceType deviceType) {
185 for (
auto attr : *arrayAttr) {
186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
187 if (deviceTypeAttr.getValue() == deviceType)
195 std::optional<mlir::ArrayAttr> deviceTypes) {
200 llvm::interleaveComma(*deviceTypes, p,
206 mlir::acc::DeviceType deviceType) {
207 unsigned segmentIdx = 0;
208 for (
auto attr : segments) {
209 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
210 if (deviceTypeAttr.getValue() == deviceType)
211 return std::make_optional(segmentIdx);
221 mlir::acc::DeviceType deviceType) {
223 return range.take_front(0);
224 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
225 int32_t nbOperandsBefore = 0;
226 for (
unsigned i = 0; i < *pos; ++i)
227 nbOperandsBefore += (*segments)[i];
228 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
230 return range.take_front(0);
237 std::optional<mlir::ArrayAttr> hasWaitDevnum,
238 mlir::acc::DeviceType deviceType) {
241 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
242 if (hasWaitDevnum->getValue()[*pos])
253 std::optional<mlir::ArrayAttr> hasWaitDevnum,
254 mlir::acc::DeviceType deviceType) {
259 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
260 if (hasWaitDevnum && *hasWaitDevnum) {
261 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
262 if (boolAttr.getValue())
263 return range.drop_front(1);
269 template <
typename Op>
271 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
273 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
278 op.hasAsyncOnly(dtype))
280 "asyncOnly attribute cannot appear with asyncOperand");
285 op.hasWaitOnly(dtype))
286 return op.
emitError(
"wait attribute cannot appear with waitOperands");
291 template <
typename Op>
294 return op.
emitError(
"must have var operand");
296 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
302 return op.
emitError(
"var must be mappable or pointer-like (not both)");
305 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
306 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
307 return op.
emitError(
"var must be mappable or pointer-like");
309 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
310 op.getVarType() != op.getVar().getType())
311 return op.
emitError(
"varType must match when var is mappable");
316 template <
typename Op>
318 if (op.getVar().getType() != op.getAccVar().getType())
319 return op.
emitError(
"input and output types must match");
341 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
362 if (failed(parser.
parseType(accVarType)))
372 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
384 mlir::TypeAttr &varTypeAttr) {
385 if (failed(parser.
parseType(varPtrType)))
401 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
403 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
412 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
420 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
421 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
423 if (typeToCheckAgainst != varType) {
434 auto extent = getExtent();
435 auto upperbound = getUpperbound();
436 if (!extent && !upperbound)
437 return emitError(
"expected extent or upperbound.");
447 "data clause associated with private operation must match its intent");
458 return emitError(
"data clause associated with firstprivate operation must "
470 return emitError(
"data clause associated with reduction operation must "
482 return emitError(
"data clause associated with deviceptr operation must "
497 "data clause associated with present operation must match its intent");
510 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
515 "data clause associated with copyin operation must match its intent"
516 " or specify original clause this operation was decomposed from");
524 bool acc::CopyinOp::isCopyinReadonly() {
525 return getDataClause() == acc::DataClause::acc_copyin_readonly;
538 "data clause associated with create operation must match its intent"
539 " or specify original clause this operation was decomposed from");
547 bool acc::CreateOp::isCreateZero() {
549 return getDataClause() == acc::DataClause::acc_create_zero ||
558 return emitError(
"data clause associated with no_create operation must "
573 "data clause associated with attach operation must match its intent");
586 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
587 return emitError(
"data clause associated with device_resident operation "
588 "must match its intent");
603 "data clause associated with link operation must match its intent");
621 "data clause associated with copyout operation must match its intent"
622 " or specify original clause this operation was decomposed from");
624 return emitError(
"must have both host and device pointers");
632 bool acc::CopyoutOp::isCopyoutZero() {
648 getDataClause() != acc::DataClause::acc_declare_device_resident &&
651 "data clause associated with delete operation must match its intent"
652 " or specify original clause this operation was decomposed from");
654 return emitError(
"must have device pointer");
666 "data clause associated with detach operation must match its intent"
667 " or specify original clause this operation was decomposed from");
669 return emitError(
"must have device pointer");
681 "data clause associated with host operation must match its intent"
682 " or specify original clause this operation was decomposed from");
684 return emitError(
"must have both host and device pointers");
699 "data clause associated with device operation must match its intent"
700 " or specify original clause this operation was decomposed from");
715 "data clause associated with use_device operation must match its intent"
716 " or specify original clause this operation was decomposed from");
732 "data clause associated with cache operation must match its intent"
733 " or specify original clause this operation was decomposed from");
741 template <
typename StructureOp>
743 unsigned nRegions = 1) {
746 for (
unsigned i = 0; i < nRegions; ++i)
747 regions.push_back(state.addRegion());
749 for (
Region *region : regions)
757 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
764 template <
typename OpTy>
768 LogicalResult matchAndRewrite(OpTy op,
771 Value ifCond = op.getIfCond();
775 IntegerAttr constAttr;
778 if (constAttr.getInt())
779 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
791 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
803 template <
typename OpTy>
804 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
807 LogicalResult matchAndRewrite(OpTy op,
810 Value ifCond = op.getIfCond();
814 IntegerAttr constAttr;
817 if (constAttr.getInt())
818 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
833 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
835 if (optional && region.
empty())
839 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
843 return op->
emitOpError() <<
"expects " << regionName
846 << regionType <<
" type";
849 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
850 if (yieldOp.getOperands().size() != 1 ||
851 yieldOp.getOperands().getTypes()[0] != type)
852 return op->
emitOpError() <<
"expects " << regionName
854 "yield a value of the "
855 << regionType <<
" type";
861 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
863 "privatization",
"init",
getType(),
867 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
877 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
879 "privatization",
"init",
getType(),
883 if (getCopyRegion().empty())
884 return emitOpError() <<
"expects non-empty copy region";
889 return emitOpError() <<
"expects copy region with two arguments of the "
890 "privatization type";
892 if (getDestroyRegion().empty())
896 "privatization",
"destroy",
907 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
913 if (getCombinerRegion().empty())
914 return emitOpError() <<
"expects non-empty combiner region";
916 Block &reductionBlock = getCombinerRegion().
front();
920 return emitOpError() <<
"expects combiner region with the first two "
921 <<
"arguments of the reduction type";
923 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
924 if (yieldOp.getOperands().size() != 1 ||
925 yieldOp.getOperands().getTypes()[0] !=
getType())
926 return emitOpError() <<
"expects combiner region to yield a value "
927 "of the reduction type";
943 if (parser.parseAttribute(attributes.emplace_back()) ||
944 parser.parseArrow() ||
945 parser.parseOperand(operands.emplace_back()) ||
946 parser.parseColonType(types.emplace_back()))
960 std::optional<mlir::ArrayAttr> attributes) {
961 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
962 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
963 << std::get<1>(it).getType();
972 template <
typename Op>
976 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
977 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
978 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
979 operand.getDefiningOp()))
981 "expect data entry/exit operation or acc.getdeviceptr "
986 template <
typename Op>
990 llvm::StringRef symbolName,
bool checkOperandType =
true) {
991 if (!operands.empty()) {
992 if (!attributes || attributes->size() != operands.size())
994 <<
"expected as many " << symbolName <<
" symbol reference as "
995 << operandName <<
" operands";
999 <<
"unexpected " << symbolName <<
" symbol reference";
1004 for (
auto args : llvm::zip(operands, *attributes)) {
1007 if (!set.insert(operand).second)
1009 << operandName <<
" operand appears more than once";
1012 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1013 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1016 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1017 << operandName <<
" declaration";
1019 if (checkOperandType && decl.getType() && decl.getType() != varType)
1020 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1021 <<
") to be the same type as " << operandName
1022 <<
" declaration (" << decl.getType() <<
")";
1028 unsigned ParallelOp::getNumDataOperands() {
1029 return getReductionOperands().size() + getPrivateOperands().size() +
1030 getFirstprivateOperands().size() + getDataClauseOperands().size();
1033 Value ParallelOp::getDataOperand(
unsigned i) {
1035 numOptional += getNumGangs().size();
1036 numOptional += getNumWorkers().size();
1037 numOptional += getVectorLength().size();
1038 numOptional += getIfCond() ? 1 : 0;
1039 numOptional += getSelfCond() ? 1 : 0;
1040 return getOperand(getWaitOperands().size() + numOptional + i);
1043 template <
typename Op>
1045 ArrayAttr deviceTypes,
1046 llvm::StringRef keyword) {
1047 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1048 return op.
emitOpError() << keyword <<
" operands count must match "
1049 << keyword <<
" device_type count";
1053 template <
typename Op>
1056 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1057 std::size_t numOperandsInSegments = 0;
1058 std::size_t nbOfSegments = 0;
1061 for (
auto segCount : segments.
asArrayRef()) {
1062 if (maxInSegment != 0 && segCount > maxInSegment)
1063 return op.
emitOpError() << keyword <<
" expects a maximum of "
1064 << maxInSegment <<
" values per segment";
1065 numOperandsInSegments += segCount;
1070 if ((numOperandsInSegments != operands.size()) ||
1071 (!deviceTypes && !operands.empty()))
1073 << keyword <<
" operand count does not match count in segments";
1074 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1076 << keyword <<
" segment count does not match device_type count";
1081 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1082 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1083 "privatizations",
false)))
1085 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1086 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1087 "firstprivate",
"firstprivatizations",
false)))
1089 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1090 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1091 "reductions",
false)))
1095 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1096 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1100 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1101 getWaitOperandsDeviceTypeAttr(),
"wait")))
1105 getNumWorkersDeviceTypeAttr(),
1110 getVectorLengthDeviceTypeAttr(),
1115 getAsyncOperandsDeviceTypeAttr(),
1119 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1122 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1128 mlir::acc::DeviceType deviceType) {
1131 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1136 bool acc::ParallelOp::hasAsyncOnly() {
1140 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1148 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1153 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1158 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1163 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1168 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1170 getVectorLength(), deviceType);
1178 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1180 getNumGangsSegments(), deviceType);
1183 bool acc::ParallelOp::hasWaitOnly() {
1187 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1196 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1198 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1199 getHasWaitDevnum(), deviceType);
1206 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1208 getWaitOperandsSegments(), getHasWaitDevnum(),
1224 odsBuilder, odsState, asyncOperands,
nullptr,
1225 nullptr, waitOperands,
nullptr,
1227 nullptr, numGangs,
nullptr,
1228 nullptr, numWorkers,
1229 nullptr, vectorLength,
1230 nullptr, ifCond, selfCond,
1231 nullptr, reductionOperands,
nullptr,
1232 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1233 nullptr, dataClauseOperands,
1237 void acc::ParallelOp::addNumWorkersOperand(
1240 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1241 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1242 getNumWorkersMutable()));
1244 void acc::ParallelOp::addVectorLengthOperand(
1247 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1248 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1249 getVectorLengthMutable()));
1252 void acc::ParallelOp::addAsyncOnly(
1254 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1255 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1258 void acc::ParallelOp::addAsyncOperand(
1261 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1262 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1263 getAsyncOperandsMutable()));
1266 void acc::ParallelOp::addNumGangsOperands(
1270 if (getNumGangsSegments())
1271 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1273 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1274 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1275 getNumGangsMutable(), segments));
1277 setNumGangsSegments(segments);
1279 void acc::ParallelOp::addWaitOnly(
1281 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1282 effectiveDeviceTypes));
1284 void acc::ParallelOp::addWaitOperands(
1289 if (getWaitOperandsSegments())
1290 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1292 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1293 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1294 getWaitOperandsMutable(), segments));
1295 setWaitOperandsSegments(segments);
1298 if (getHasWaitDevnumAttr())
1299 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1302 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1319 int32_t crtOperandsSize = operands.size();
1322 if (parser.parseOperand(operands.emplace_back()) ||
1323 parser.parseColonType(types.emplace_back()))
1328 seg.push_back(operands.size() - crtOperandsSize);
1352 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1354 p <<
" [" << attr <<
"]";
1359 std::optional<mlir::ArrayAttr> deviceTypes,
1360 std::optional<mlir::DenseI32ArrayAttr> segments) {
1362 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1364 llvm::interleaveComma(
1365 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1366 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1386 int32_t crtOperandsSize = operands.size();
1390 if (parser.parseOperand(operands.emplace_back()) ||
1391 parser.parseColonType(types.emplace_back()))
1397 seg.push_back(operands.size() - crtOperandsSize);
1423 std::optional<mlir::DenseI32ArrayAttr> segments) {
1425 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1427 llvm::interleaveComma(
1428 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1429 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1442 mlir::ArrayAttr &keywordOnly) {
1446 bool needCommaBeforeOperands =
false;
1459 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1466 needCommaBeforeOperands =
true;
1469 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1476 int32_t crtOperandsSize = operands.size();
1488 if (parser.parseOperand(operands.emplace_back()) ||
1489 parser.parseColonType(types.emplace_back()))
1495 seg.push_back(operands.size() - crtOperandsSize);
1524 if (attrs->size() != 1)
1526 if (
auto deviceTypeAttr =
1527 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1534 std::optional<mlir::ArrayAttr> deviceTypes,
1535 std::optional<mlir::DenseI32ArrayAttr> segments,
1536 std::optional<mlir::ArrayAttr> hasDevNum,
1537 std::optional<mlir::ArrayAttr> keywordOnly) {
1550 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1552 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1553 if (boolAttr && boolAttr.getValue())
1555 llvm::interleaveComma(
1556 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1557 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1574 if (parser.parseOperand(operands.emplace_back()) ||
1575 parser.parseColonType(types.emplace_back()))
1577 if (succeeded(parser.parseOptionalLSquare())) {
1578 if (parser.parseAttribute(attributes.emplace_back()) ||
1579 parser.parseRSquare())
1582 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1583 parser.getContext(), mlir::acc::DeviceType::None));
1597 std::optional<mlir::ArrayAttr> deviceTypes) {
1600 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1601 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1610 mlir::ArrayAttr &keywordOnlyDeviceType) {
1613 bool needCommaBeforeOperands =
false;
1619 keywordOnlyDeviceType =
1628 if (parser.parseAttribute(
1629 keywordOnlyDeviceTypeAttributes.emplace_back()))
1636 needCommaBeforeOperands =
true;
1639 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1644 if (parser.parseOperand(operands.emplace_back()) ||
1645 parser.parseColonType(types.emplace_back()))
1647 if (succeeded(parser.parseOptionalLSquare())) {
1648 if (parser.parseAttribute(attributes.emplace_back()) ||
1649 parser.parseRSquare())
1652 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1653 parser.getContext(), mlir::acc::DeviceType::None));
1659 if (failed(parser.parseRParen()))
1671 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1673 if (operands.begin() == operands.end() &&
1689 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1690 mlir::Type &operandType, mlir::UnitAttr &attr) {
1703 if (failed(parser.
parseType(operandType)))
1713 std::optional<mlir::Value> operand,
1715 mlir::UnitAttr attr) {
1737 if (parser.parseOperand(operands.emplace_back()))
1745 if (parser.parseType(types.emplace_back()))
1760 mlir::UnitAttr attr) {
1765 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
1767 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
1773 mlir::acc::CombinedConstructsTypeAttr &attr) {
1776 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1779 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1782 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1785 "expected compute construct name");
1793 mlir::acc::CombinedConstructsTypeAttr attr) {
1795 switch (attr.getValue()) {
1796 case mlir::acc::CombinedConstructsType::KernelsLoop:
1799 case mlir::acc::CombinedConstructsType::ParallelLoop:
1802 case mlir::acc::CombinedConstructsType::SerialLoop:
1813 unsigned SerialOp::getNumDataOperands() {
1814 return getReductionOperands().size() + getPrivateOperands().size() +
1815 getFirstprivateOperands().size() + getDataClauseOperands().size();
1818 Value SerialOp::getDataOperand(
unsigned i) {
1820 numOptional += getIfCond() ? 1 : 0;
1821 numOptional += getSelfCond() ? 1 : 0;
1822 return getOperand(getWaitOperands().size() + numOptional + i);
1825 bool acc::SerialOp::hasAsyncOnly() {
1829 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1837 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1842 bool acc::SerialOp::hasWaitOnly() {
1846 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1855 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1857 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1858 getHasWaitDevnum(), deviceType);
1865 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1867 getWaitOperandsSegments(), getHasWaitDevnum(),
1872 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1873 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1874 "privatizations",
false)))
1876 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1877 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1878 "firstprivate",
"firstprivatizations",
false)))
1880 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1881 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1882 "reductions",
false)))
1886 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1887 getWaitOperandsDeviceTypeAttr(),
"wait")))
1891 getAsyncOperandsDeviceTypeAttr(),
1895 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1898 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1901 void acc::SerialOp::addAsyncOnly(
1903 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1904 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1907 void acc::SerialOp::addAsyncOperand(
1910 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912 getAsyncOperandsMutable()));
1915 void acc::SerialOp::addWaitOnly(
1917 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1918 effectiveDeviceTypes));
1920 void acc::SerialOp::addWaitOperands(
1925 if (getWaitOperandsSegments())
1926 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1928 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1929 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1930 getWaitOperandsMutable(), segments));
1931 setWaitOperandsSegments(segments);
1934 if (getHasWaitDevnumAttr())
1935 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1938 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1947 unsigned KernelsOp::getNumDataOperands() {
1948 return getDataClauseOperands().size();
1951 Value KernelsOp::getDataOperand(
unsigned i) {
1953 numOptional += getWaitOperands().size();
1954 numOptional += getNumGangs().size();
1955 numOptional += getNumWorkers().size();
1956 numOptional += getVectorLength().size();
1957 numOptional += getIfCond() ? 1 : 0;
1958 numOptional += getSelfCond() ? 1 : 0;
1959 return getOperand(numOptional + i);
1962 bool acc::KernelsOp::hasAsyncOnly() {
1966 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1974 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1979 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1984 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1989 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1994 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1996 getVectorLength(), deviceType);
2004 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2006 getNumGangsSegments(), deviceType);
2009 bool acc::KernelsOp::hasWaitOnly() {
2013 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2022 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2024 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2025 getHasWaitDevnum(), deviceType);
2032 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2034 getWaitOperandsSegments(), getHasWaitDevnum(),
2040 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2041 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2045 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2046 getWaitOperandsDeviceTypeAttr(),
"wait")))
2050 getNumWorkersDeviceTypeAttr(),
2055 getVectorLengthDeviceTypeAttr(),
2060 getAsyncOperandsDeviceTypeAttr(),
2064 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
2067 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
2070 void acc::KernelsOp::addNumWorkersOperand(
2073 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2074 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2075 getNumWorkersMutable()));
2078 void acc::KernelsOp::addVectorLengthOperand(
2081 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2082 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2083 getVectorLengthMutable()));
2085 void acc::KernelsOp::addAsyncOnly(
2087 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2088 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2091 void acc::KernelsOp::addAsyncOperand(
2094 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2095 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2096 getAsyncOperandsMutable()));
2099 void acc::KernelsOp::addNumGangsOperands(
2103 if (getNumGangsSegmentsAttr())
2104 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2106 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2107 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2108 getNumGangsMutable(), segments));
2110 setNumGangsSegments(segments);
2113 void acc::KernelsOp::addWaitOnly(
2115 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2116 effectiveDeviceTypes));
2118 void acc::KernelsOp::addWaitOperands(
2123 if (getWaitOperandsSegments())
2124 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2126 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2127 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2128 getWaitOperandsMutable(), segments));
2129 setWaitOperandsSegments(segments);
2132 if (getHasWaitDevnumAttr())
2133 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2136 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2146 if (getDataClauseOperands().empty())
2147 return emitError(
"at least one operand must appear on the host_data "
2150 for (
mlir::Value operand : getDataClauseOperands())
2151 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2152 return emitError(
"expect data entry operation as defining op");
2158 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2170 bool &needCommaBetweenValues,
bool &newValue) {
2177 attributes.push_back(gangArgType);
2178 needCommaBetweenValues =
true;
2189 mlir::ArrayAttr &gangOnlyDeviceType) {
2194 bool needCommaBetweenValues =
false;
2195 bool needCommaBeforeOperands =
false;
2201 gangOnlyDeviceType =
2210 if (parser.parseAttribute(
2211 gangOnlyDeviceTypeAttributes.emplace_back()))
2218 needCommaBeforeOperands =
true;
2222 mlir::acc::GangArgType::Num);
2224 mlir::acc::GangArgType::Dim);
2226 parser.
getContext(), mlir::acc::GangArgType::Static);
2229 if (needCommaBeforeOperands) {
2230 needCommaBeforeOperands =
false;
2237 int32_t crtOperandsSize = gangOperands.size();
2239 bool newValue =
false;
2240 bool needValue =
false;
2241 if (needCommaBetweenValues) {
2249 gangOperands, gangOperandsType,
2250 gangArgTypeAttributes, argNum,
2251 needCommaBetweenValues, newValue)))
2254 gangOperands, gangOperandsType,
2255 gangArgTypeAttributes, argDim,
2256 needCommaBetweenValues, newValue)))
2258 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2259 gangOperands, gangOperandsType,
2260 gangArgTypeAttributes, argStatic,
2261 needCommaBetweenValues, newValue)))
2264 if (!newValue && needValue) {
2266 "new value expected after comma");
2274 if (gangOperands.empty())
2277 "expect at least one of num, dim or static values");
2283 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2291 seg.push_back(gangOperands.size() - crtOperandsSize);
2299 gangArgTypeAttributes.end());
2304 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2313 std::optional<mlir::ArrayAttr> gangArgTypes,
2314 std::optional<mlir::ArrayAttr> deviceTypes,
2315 std::optional<mlir::DenseI32ArrayAttr> segments,
2316 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2318 if (operands.begin() == operands.end() &&
2333 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2335 llvm::interleaveComma(
2336 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2337 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2338 (*gangArgTypes)[opIdx]);
2339 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2340 p << LoopOp::getGangNumKeyword();
2341 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2342 p << LoopOp::getGangDimKeyword();
2343 else if (gangArgTypeAttr.getValue() ==
2344 mlir::acc::GangArgType::Static)
2345 p << LoopOp::getGangStaticKeyword();
2346 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2357 std::optional<mlir::ArrayAttr> segments,
2358 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2361 for (
auto attr : *segments) {
2362 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2363 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2371 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2374 for (
auto attr : deviceTypes) {
2375 auto deviceTypeAttr =
2376 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2377 if (!deviceTypeAttr)
2379 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2386 if (getUpperbound().size() != getStep().size())
2387 return emitError() <<
"number of upperbounds expected to be the same as "
2390 if (getUpperbound().size() != getLowerbound().size())
2391 return emitError() <<
"number of upperbounds expected to be the same as "
2392 "number of lowerbounds";
2394 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2395 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2396 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2397 <<
" as upperbound size";
2400 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2401 return emitOpError() <<
"collapse device_type attr must be define when"
2402 <<
" collapse attr is present";
2404 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2405 getCollapseAttr().getValue().size() !=
2406 getCollapseDeviceTypeAttr().getValue().size())
2407 return emitOpError() <<
"collapse attribute count must match collapse"
2408 <<
" device_type count";
2410 return emitOpError()
2411 <<
"duplicate device_type found in collapseDeviceType attribute";
2414 if (!getGangOperands().empty()) {
2415 if (!getGangOperandsArgType())
2416 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2417 <<
" when gang operands are present";
2419 if (getGangOperands().size() !=
2420 getGangOperandsArgTypeAttr().getValue().size())
2421 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2422 <<
" gangOperands count";
2425 return emitOpError() <<
"duplicate device_type found in gang attribute";
2428 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2429 getGangOperandsDeviceTypeAttr(),
"gang")))
2434 return emitOpError() <<
"duplicate device_type found in worker attribute";
2436 return emitOpError() <<
"duplicate device_type found in "
2437 "workerNumOperandsDeviceType attribute";
2439 getWorkerNumOperandsDeviceTypeAttr(),
2445 return emitOpError() <<
"duplicate device_type found in vector attribute";
2447 return emitOpError() <<
"duplicate device_type found in "
2448 "vectorOperandsDeviceType attribute";
2450 getVectorOperandsDeviceTypeAttr(),
2455 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2456 getTileOperandsDeviceTypeAttr(),
"tile")))
2460 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2464 return emitError() <<
"only one of auto, independent, seq can be present "
2470 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
2473 bool hasDefaultSeq =
2475 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2478 bool hasDefaultIndependent =
2479 getIndependentAttr()
2481 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2484 bool hasDefaultAuto =
2486 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2489 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2491 <<
"at least one of auto, independent, seq must be present";
2496 for (
auto attr : getSeqAttr()) {
2497 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2498 if (hasVector(deviceTypeAttr.getValue()) ||
2499 getVectorValue(deviceTypeAttr.getValue()) ||
2500 hasWorker(deviceTypeAttr.getValue()) ||
2501 getWorkerValue(deviceTypeAttr.getValue()) ||
2502 hasGang(deviceTypeAttr.getValue()) ||
2503 getGangValue(mlir::acc::GangArgType::Num,
2504 deviceTypeAttr.getValue()) ||
2505 getGangValue(mlir::acc::GangArgType::Dim,
2506 deviceTypeAttr.getValue()) ||
2507 getGangValue(mlir::acc::GangArgType::Static,
2508 deviceTypeAttr.getValue()))
2509 return emitError() <<
"gang, worker or vector cannot appear with seq";
2513 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2514 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2515 "privatizations",
false)))
2518 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2519 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2520 "reductions",
false)))
2523 if (getCombined().has_value() &&
2524 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2525 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2526 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2527 return emitError(
"unexpected combined constructs attribute");
2531 if (getRegion().empty())
2532 return emitError(
"expected non-empty body.");
2535 if (isContainerLike()) {
2538 uint64_t collapseCount = getCollapseValue().value_or(1);
2539 if (getCollapseAttr()) {
2540 for (
auto collapseEntry : getCollapseAttr()) {
2541 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2542 if (intAttr.getValue().getZExtValue() > collapseCount)
2543 collapseCount = intAttr.getValue().getZExtValue();
2551 bool foundSibling =
false;
2553 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2555 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2557 foundSibling =
true;
2562 expectedParent = op;
2565 if (collapseCount == 0)
2571 return emitError(
"found sibling loops inside container-like acc.loop");
2572 if (collapseCount != 0)
2573 return emitError(
"failed to find enough loop-like operations inside "
2574 "container-like acc.loop");
2580 unsigned LoopOp::getNumDataOperands() {
2581 return getReductionOperands().size() + getPrivateOperands().size();
2584 Value LoopOp::getDataOperand(
unsigned i) {
2585 unsigned numOptional =
2586 getLowerbound().size() + getUpperbound().size() + getStep().size();
2587 numOptional += getGangOperands().size();
2588 numOptional += getVectorOperands().size();
2589 numOptional += getWorkerNumOperands().size();
2590 numOptional += getTileOperands().size();
2591 numOptional += getCacheOperands().size();
2592 return getOperand(numOptional + i);
2597 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2601 bool LoopOp::hasIndependent() {
2605 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2611 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2619 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2621 getVectorOperands(), deviceType);
2626 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2634 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2636 getWorkerNumOperands(), deviceType);
2641 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2650 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2652 getTileOperandsSegments(), deviceType);
2655 std::optional<int64_t> LoopOp::getCollapseValue() {
2659 std::optional<int64_t>
2660 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2661 if (!getCollapseAttr())
2662 return std::nullopt;
2663 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2665 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2666 return intAttr.getValue().getZExtValue();
2668 return std::nullopt;
2671 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2675 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2676 mlir::acc::DeviceType deviceType) {
2677 if (getGangOperands().empty())
2679 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2680 int32_t nbOperandsBefore = 0;
2681 for (
unsigned i = 0; i < *pos; ++i)
2682 nbOperandsBefore += (*getGangOperandsSegments())[i];
2685 .drop_front(nbOperandsBefore)
2686 .take_front((*getGangOperandsSegments())[*pos]);
2688 int32_t argTypeIdx = nbOperandsBefore;
2689 for (
auto value : values) {
2690 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2691 (*getGangOperandsArgType())[argTypeIdx]);
2692 if (gangArgTypeAttr.getValue() == gangArgType)
2702 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2707 return {&getRegion()};
2751 if (!regionArgs.empty()) {
2752 p << acc::LoopOp::getControlKeyword() <<
"(";
2753 llvm::interleaveComma(regionArgs, p,
2755 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2756 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2757 <<
" : " << stepType <<
") ";
2764 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2765 effectiveDeviceTypes));
2768 void acc::LoopOp::addIndependent(
2770 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2771 context, getIndependentAttr(), effectiveDeviceTypes));
2776 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2777 effectiveDeviceTypes));
2780 void acc::LoopOp::setCollapseForDeviceTypes(
2782 llvm::APInt value) {
2786 assert((getCollapseAttr() ==
nullptr) ==
2787 (getCollapseDeviceTypeAttr() ==
nullptr));
2788 assert(value.getBitWidth() == 64);
2790 if (getCollapseAttr()) {
2791 for (
const auto &existing :
2792 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2793 newValues.push_back(std::get<0>(existing));
2794 newDeviceTypes.push_back(std::get<1>(existing));
2798 if (effectiveDeviceTypes.empty()) {
2801 newValues.push_back(
2803 newDeviceTypes.push_back(
2806 for (DeviceType DT : effectiveDeviceTypes) {
2807 newValues.push_back(
2814 setCollapseDeviceTypeAttr(
ArrayAttr::get(context, newDeviceTypes));
2817 void acc::LoopOp::setTileForDeviceTypes(
2821 if (getTileOperandsSegments())
2822 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2824 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2825 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2826 getTileOperandsMutable(), segments));
2828 setTileOperandsSegments(segments);
2831 void acc::LoopOp::addVectorOperand(
2834 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2835 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2836 newValue, getVectorOperandsMutable()));
2839 void acc::LoopOp::addEmptyVector(
2841 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
2842 effectiveDeviceTypes));
2845 void acc::LoopOp::addWorkerNumOperand(
2848 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2850 newValue, getWorkerNumOperandsMutable()));
2853 void acc::LoopOp::addEmptyWorker(
2855 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
2856 effectiveDeviceTypes));
2859 void acc::LoopOp::addEmptyGang(
2861 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
2862 effectiveDeviceTypes));
2865 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
2866 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
2867 return attr.getValue() == dt;
2869 auto testFromArr = [=](ArrayAttr arr) ->
bool {
2870 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
2873 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
2875 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
2877 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
2883 bool acc::LoopOp::hasDefaultGangWorkerVector() {
2884 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
2885 hasGang() || getGangValue(GangArgType::Num) ||
2886 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
2889 void acc::LoopOp::addGangOperands(
2894 getGangOperandsSegments())
2895 llvm::copy(*existingSegments, std::back_inserter(segments));
2897 unsigned beforeCount = segments.size();
2899 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2900 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2901 getGangOperandsMutable(), segments));
2903 setGangOperandsSegments(segments);
2910 unsigned numAdded = segments.size() - beforeCount;
2914 if (getGangOperandsArgTypeAttr())
2915 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
2917 for (
auto i : llvm::index_range(0u, numAdded)) {
2918 llvm::transform(argTypes, std::back_inserter(gangTypes),
2919 [=](mlir::acc::GangArgType gangTy) {
2937 if (getOperands().empty() && !getDefaultAttr())
2938 return emitError(
"at least one operand or the default attribute "
2939 "must appear on the data operation");
2941 for (
mlir::Value operand : getDataClauseOperands())
2942 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2943 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2944 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2945 operand.getDefiningOp()))
2946 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2949 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2955 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2957 Value DataOp::getDataOperand(
unsigned i) {
2958 unsigned numOptional = getIfCond() ? 1 : 0;
2960 numOptional += getWaitOperands().size();
2961 return getOperand(numOptional + i);
2964 bool acc::DataOp::hasAsyncOnly() {
2968 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2976 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2983 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2992 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2994 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2995 getHasWaitDevnum(), deviceType);
3002 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3004 getWaitOperandsSegments(), getHasWaitDevnum(),
3008 void acc::DataOp::addAsyncOnly(
3010 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3011 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3014 void acc::DataOp::addAsyncOperand(
3017 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3018 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3019 getAsyncOperandsMutable()));
3022 void acc::DataOp::addWaitOnly(
MLIRContext *context,
3024 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3025 effectiveDeviceTypes));
3028 void acc::DataOp::addWaitOperands(
3033 if (getWaitOperandsSegments())
3034 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3036 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3037 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3038 getWaitOperandsMutable(), segments));
3039 setWaitOperandsSegments(segments);
3042 if (getHasWaitDevnumAttr())
3043 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3046 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3059 if (getDataClauseOperands().empty())
3060 return emitError(
"at least one operand must be present in dataOperands on "
3061 "the exit data operation");
3065 if (getAsyncOperand() && getAsync())
3066 return emitError(
"async attribute cannot appear with asyncOperand");
3070 if (!getWaitOperands().empty() && getWait())
3071 return emitError(
"wait attribute cannot appear with waitOperands");
3073 if (getWaitDevnum() && getWaitOperands().empty())
3074 return emitError(
"wait_devnum cannot appear without waitOperands");
3079 unsigned ExitDataOp::getNumDataOperands() {
3080 return getDataClauseOperands().size();
3083 Value ExitDataOp::getDataOperand(
unsigned i) {
3084 unsigned numOptional = getIfCond() ? 1 : 0;
3085 numOptional += getAsyncOperand() ? 1 : 0;
3086 numOptional += getWaitDevnum() ? 1 : 0;
3087 return getOperand(getWaitOperands().size() + numOptional + i);
3092 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3103 if (getDataClauseOperands().empty())
3104 return emitError(
"at least one operand must be present in dataOperands on "
3105 "the enter data operation");
3109 if (getAsyncOperand() && getAsync())
3110 return emitError(
"async attribute cannot appear with asyncOperand");
3114 if (!getWaitOperands().empty() && getWait())
3115 return emitError(
"wait attribute cannot appear with waitOperands");
3117 if (getWaitDevnum() && getWaitOperands().empty())
3118 return emitError(
"wait_devnum cannot appear without waitOperands");
3120 for (
mlir::Value operand : getDataClauseOperands())
3121 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3122 operand.getDefiningOp()))
3123 return emitError(
"expect data entry operation as defining op");
3128 unsigned EnterDataOp::getNumDataOperands() {
3129 return getDataClauseOperands().size();
3132 Value EnterDataOp::getDataOperand(
unsigned i) {
3133 unsigned numOptional = getIfCond() ? 1 : 0;
3134 numOptional += getAsyncOperand() ? 1 : 0;
3135 numOptional += getWaitDevnum() ? 1 : 0;
3136 return getOperand(getWaitOperands().size() + numOptional + i);
3141 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3160 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3167 if (
Value writeVal = op.getWriteOpVal()) {
3177 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3183 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3184 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3186 return dyn_cast<AtomicReadOp>(getSecondOp());
3189 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3190 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3192 return dyn_cast<AtomicWriteOp>(getSecondOp());
3195 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3196 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3198 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3201 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3207 template <
typename Op>
3208 static LogicalResult
3210 bool requireAtLeastOneOperand =
true) {
3211 if (operands.empty() && requireAtLeastOneOperand)
3214 "at least one operand must appear on the declare operation");
3217 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3218 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3219 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3220 operand.getDefiningOp()))
3222 "expect valid declare data entry operation or acc.getdeviceptr "
3226 assert(var &&
"declare operands can only be data entry operations which "
3229 std::optional<mlir::acc::DataClause> dataClauseOptional{
3231 assert(dataClauseOptional.has_value() &&
3232 "declare operands can only be data entry operations which must have "
3234 (void)dataClauseOptional;
3268 acc::DeviceType dtype) {
3269 unsigned parallelism = 0;
3270 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3271 parallelism += op.hasWorker(dtype) ? 1 : 0;
3272 parallelism += op.hasVector(dtype) ? 1 : 0;
3273 parallelism += op.hasSeq(dtype) ? 1 : 0;
3278 unsigned baseParallelism =
3281 if (baseParallelism > 1)
3282 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3283 "be present at the same time";
3285 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3287 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
3292 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3293 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3294 "be present at the same time";
3301 mlir::ArrayAttr &deviceTypes) {
3306 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
3308 if (failed(parser.parseOptionalLSquare())) {
3309 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3310 parser.getContext(), mlir::acc::DeviceType::None));
3312 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
3313 parser.parseRSquare())
3321 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
3327 std::optional<mlir::ArrayAttr> bindName,
3328 std::optional<mlir::ArrayAttr> deviceTypes) {
3329 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
3330 [&](
const auto &pair) {
3331 p << std::get<0>(pair);
3337 mlir::ArrayAttr &gang,
3338 mlir::ArrayAttr &gangDim,
3339 mlir::ArrayAttr &gangDimDeviceTypes) {
3342 gangDimDeviceTypeAttrs;
3343 bool needCommaBeforeOperands =
false;
3356 if (parser.parseAttribute(gangAttrs.emplace_back()))
3363 needCommaBeforeOperands =
true;
3366 if (needCommaBeforeOperands && failed(parser.
parseComma()))
3370 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3371 parser.parseColon() ||
3372 parser.parseAttribute(gangDimAttrs.emplace_back()))
3374 if (succeeded(parser.parseOptionalLSquare())) {
3375 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3376 parser.parseRSquare())
3379 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3380 parser.getContext(), mlir::acc::DeviceType::None));
3386 if (failed(parser.parseRParen()))
3391 gangDimDeviceTypes =
3398 std::optional<mlir::ArrayAttr> gang,
3399 std::optional<mlir::ArrayAttr> gangDim,
3400 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3403 gang->size() == 1) {
3404 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3417 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3418 [&](
const auto &pair) {
3419 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
3420 p << std::get<0>(pair);
3428 mlir::ArrayAttr &deviceTypes) {
3441 if (parser.parseAttribute(attributes.emplace_back()))
3455 std::optional<mlir::ArrayAttr> deviceTypes) {
3458 auto deviceTypeAttr =
3459 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3469 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3477 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3483 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3489 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3493 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
3497 std::optional<llvm::StringRef>
3498 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3500 return std::nullopt;
3501 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
3502 auto attr = (*getBindName())[*pos];
3503 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3504 return stringAttr.getValue();
3506 return std::nullopt;
3511 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3515 std::optional<int64_t> RoutineOp::getGangDimValue() {
3519 std::optional<int64_t>
3520 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3522 return std::nullopt;
3523 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
3524 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3525 return intAttr.getInt();
3527 return std::nullopt;
3538 return emitOpError(
"cannot be nested in a compute operation");
3542 void acc::InitOp::addDeviceType(
MLIRContext *context,
3543 mlir::acc::DeviceType deviceType) {
3545 if (getDeviceTypesAttr())
3546 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3560 return emitOpError(
"cannot be nested in a compute operation");
3564 void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
3565 mlir::acc::DeviceType deviceType) {
3567 if (getDeviceTypesAttr())
3568 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3582 return emitOpError(
"cannot be nested in a compute operation");
3583 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3584 return emitOpError(
"at least one default_async, device_num, or device_type "
3585 "operand must appear");
3595 if (getDataClauseOperands().empty())
3596 return emitError(
"at least one value must be present in dataOperands");
3599 getAsyncOperandsDeviceTypeAttr(),
3604 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3605 getWaitOperandsDeviceTypeAttr(),
"wait")))
3608 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
3611 for (
mlir::Value operand : getDataClauseOperands())
3612 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3613 operand.getDefiningOp()))
3614 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3620 unsigned UpdateOp::getNumDataOperands() {
3621 return getDataClauseOperands().size();
3624 Value UpdateOp::getDataOperand(
unsigned i) {
3626 numOptional += getIfCond() ? 1 : 0;
3627 return getOperand(getWaitOperands().size() + numOptional + i);
3632 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
3635 bool UpdateOp::hasAsyncOnly() {
3639 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3647 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3657 bool UpdateOp::hasWaitOnly() {
3661 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3670 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3672 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3673 getHasWaitDevnum(), deviceType);
3680 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3682 getWaitOperandsSegments(), getHasWaitDevnum(),
3693 if (getAsyncOperand() && getAsync())
3694 return emitError(
"async attribute cannot appear with asyncOperand");
3696 if (getWaitDevnum() && getWaitOperands().empty())
3697 return emitError(
"wait_devnum cannot appear without waitOperands");
3702 #define GET_OP_CLASSES
3703 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3705 #define GET_ATTRDEF_CLASSES
3706 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3708 #define GET_TYPEDEF_CLASSES
3709 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3720 .Case<ACC_DATA_ENTRY_OPS>(
3721 [&](
auto entry) {
return entry.getVarPtr(); })
3722 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3723 [&](
auto exit) {
return exit.getVarPtr(); })
3741 [&](
auto entry) {
return entry.getVarType(); })
3742 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3743 [&](
auto exit) {
return exit.getVarType(); })
3753 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3754 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
3764 [&](
auto dataClause) {
return dataClause.getAccVar(); })
3773 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
3783 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3785 dataClause.getBounds().begin(), dataClause.getBounds().end());
3797 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3799 dataClause.getAsyncOperands().begin(),
3800 dataClause.getAsyncOperands().end());
3811 return dataClause.getAsyncOperandsDeviceTypeAttr();
3819 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
3826 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
3833 std::optional<mlir::acc::DataClause>
3838 .Case<ACC_DATA_ENTRY_OPS>(
3839 [&](
auto entry) {
return entry.getDataClause(); })
3847 [&](
auto entry) {
return entry.getImplicit(); })
3856 [&](
auto entry) {
return entry.getDataClauseOperands(); })
3858 return dataOperands;
3866 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
3868 return dataOperands;
3874 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
mlir::Operation * getEnclosingComputeOp(mlir::Region ®ion)
Used to obtain the enclosing compute construct operation that contains the provided region.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.