19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
31 struct MemRefPointerLikeModel
32 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
35 return llvm::cast<MemRefType>(pointer).getElementType();
39 struct LLVMPointerPointerLikeModel
40 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
41 LLVM::LLVMPointerType> {
50 void OpenACCDialect::initialize() {
53 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
56 #define GET_ATTRDEF_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
60 #define GET_TYPEDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
67 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
68 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
77 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
83 mlir::acc::DeviceType deviceType) {
87 for (
auto attr : *arrayAttr) {
88 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89 if (deviceTypeAttr.getValue() == deviceType)
97 std::optional<mlir::ArrayAttr> deviceTypes) {
102 llvm::interleaveComma(*deviceTypes, p,
108 mlir::acc::DeviceType deviceType) {
109 unsigned segmentIdx = 0;
110 for (
auto attr : segments) {
111 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
112 if (deviceTypeAttr.getValue() == deviceType)
113 return std::make_optional(segmentIdx);
123 mlir::acc::DeviceType deviceType) {
125 return range.take_front(0);
126 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
127 int32_t nbOperandsBefore = 0;
128 for (
unsigned i = 0; i < *pos; ++i)
129 nbOperandsBefore += (*segments)[i];
130 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
132 return range.take_front(0);
139 std::optional<mlir::ArrayAttr> hasWaitDevnum,
140 mlir::acc::DeviceType deviceType) {
143 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
144 if (hasWaitDevnum->getValue()[*pos])
155 std::optional<mlir::ArrayAttr> hasWaitDevnum,
156 mlir::acc::DeviceType deviceType) {
161 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
162 if (hasWaitDevnum && *hasWaitDevnum) {
163 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
164 if (boolAttr.getValue())
165 return range.drop_front(1);
171 template <
typename Op>
173 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
175 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
180 op.hasAsyncOnly(dtype))
181 return op.
emitError(
"async attribute cannot appear with asyncOperand");
186 op.hasWaitOnly(dtype))
187 return op.
emitError(
"wait attribute cannot appear with waitOperands");
196 auto extent = getExtent();
197 auto upperbound = getUpperbound();
198 if (!extent && !upperbound)
199 return emitError(
"expected extent or upperbound.");
209 "data clause associated with private operation must match its intent");
218 return emitError(
"data clause associated with firstprivate operation must "
228 return emitError(
"data clause associated with reduction operation must "
238 return emitError(
"data clause associated with deviceptr operation must "
249 "data clause associated with present operation must match its intent");
258 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
263 "data clause associated with copyin operation must match its intent"
264 " or specify original clause this operation was decomposed from");
268 bool acc::CopyinOp::isCopyinReadonly() {
269 return getDataClause() == acc::DataClause::acc_copyin_readonly;
282 "data clause associated with create operation must match its intent"
283 " or specify original clause this operation was decomposed from");
287 bool acc::CreateOp::isCreateZero() {
289 return getDataClause() == acc::DataClause::acc_create_zero ||
298 return emitError(
"data clause associated with no_create operation must "
309 "data clause associated with attach operation must match its intent");
318 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
319 return emitError(
"data clause associated with device_resident operation "
320 "must match its intent");
331 "data clause associated with link operation must match its intent");
345 "data clause associated with copyout operation must match its intent"
346 " or specify original clause this operation was decomposed from");
348 return emitError(
"must have both host and device pointers");
352 bool acc::CopyoutOp::isCopyoutZero() {
367 getDataClause() != acc::DataClause::acc_declare_device_resident &&
370 "data clause associated with delete operation must match its intent"
371 " or specify original clause this operation was decomposed from");
373 return emitError(
"must have device pointer");
385 "data clause associated with detach operation must match its intent"
386 " or specify original clause this operation was decomposed from");
388 return emitError(
"must have device pointer");
400 "data clause associated with host operation must match its intent"
401 " or specify original clause this operation was decomposed from");
403 return emitError(
"must have both host and device pointers");
414 "data clause associated with device operation must match its intent"
415 " or specify original clause this operation was decomposed from");
426 "data clause associated with use_device operation must match its intent"
427 " or specify original clause this operation was decomposed from");
439 "data clause associated with cache operation must match its intent"
440 " or specify original clause this operation was decomposed from");
444 template <
typename StructureOp>
446 unsigned nRegions = 1) {
449 for (
unsigned i = 0; i < nRegions; ++i)
450 regions.push_back(state.addRegion());
452 for (
Region *region : regions)
460 return isa<acc::ParallelOp, acc::LoopOp>(op);
467 template <
typename OpTy>
474 Value ifCond = op.getIfCond();
478 IntegerAttr constAttr;
481 if (constAttr.getInt())
482 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
494 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
506 template <
typename OpTy>
507 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
513 Value ifCond = op.getIfCond();
517 IntegerAttr constAttr;
520 if (constAttr.getInt())
521 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
536 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
538 if (optional && region.
empty())
542 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
546 return op->
emitOpError() <<
"expects " << regionName
549 << regionType <<
" type";
552 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
553 if (yieldOp.getOperands().size() != 1 ||
554 yieldOp.getOperands().getTypes()[0] != type)
555 return op->
emitOpError() <<
"expects " << regionName
557 "yield a value of the "
558 << regionType <<
" type";
566 "privatization",
"init", getType(),
570 *
this, getDestroyRegion(),
"privatization",
"destroy", getType(),
582 "privatization",
"init", getType(),
586 if (getCopyRegion().empty())
587 return emitOpError() <<
"expects non-empty copy region";
592 return emitOpError() <<
"expects copy region with two arguments of the "
593 "privatization type";
595 if (getDestroyRegion().empty())
599 "privatization",
"destroy",
616 if (getCombinerRegion().empty())
617 return emitOpError() <<
"expects non-empty combiner region";
619 Block &reductionBlock = getCombinerRegion().
front();
623 return emitOpError() <<
"expects combiner region with the first two "
624 <<
"arguments of the reduction type";
626 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
627 if (yieldOp.getOperands().size() != 1 ||
628 yieldOp.getOperands().getTypes()[0] != getType())
629 return emitOpError() <<
"expects combiner region to yield a value "
630 "of the reduction type";
646 if (parser.parseAttribute(attributes.emplace_back()) ||
647 parser.parseArrow() ||
648 parser.parseOperand(operands.emplace_back()) ||
649 parser.parseColonType(types.emplace_back()))
663 std::optional<mlir::ArrayAttr> attributes) {
664 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
665 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
666 << std::get<1>(it).getType();
675 template <
typename Op>
679 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
680 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
681 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
682 operand.getDefiningOp()))
684 "expect data entry/exit operation or acc.getdeviceptr "
689 template <
typename Op>
693 llvm::StringRef symbolName,
bool checkOperandType =
true) {
694 if (!operands.empty()) {
695 if (!attributes || attributes->size() != operands.size())
697 <<
"expected as many " << symbolName <<
" symbol reference as "
698 << operandName <<
" operands";
702 <<
"unexpected " << symbolName <<
" symbol reference";
707 for (
auto args : llvm::zip(operands, *attributes)) {
710 if (!set.insert(operand).second)
712 << operandName <<
" operand appears more than once";
715 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
716 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
719 <<
"expected symbol reference " << symbolRef <<
" to point to a "
720 << operandName <<
" declaration";
722 if (checkOperandType && decl.getType() && decl.getType() != varType)
723 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
724 <<
") to be the same type as " << operandName
725 <<
" declaration (" << decl.getType() <<
")";
731 unsigned ParallelOp::getNumDataOperands() {
732 return getReductionOperands().size() + getGangPrivateOperands().size() +
733 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
736 Value ParallelOp::getDataOperand(
unsigned i) {
737 unsigned numOptional = getAsyncOperands().size();
738 numOptional += getNumGangs().size();
739 numOptional += getNumWorkers().size();
740 numOptional += getVectorLength().size();
741 numOptional += getIfCond() ? 1 : 0;
742 numOptional += getSelfCond() ? 1 : 0;
743 return getOperand(getWaitOperands().size() + numOptional + i);
746 template <
typename Op>
748 ArrayAttr deviceTypes,
749 llvm::StringRef keyword) {
750 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
751 return op.
emitOpError() << keyword <<
" operands count must match "
752 << keyword <<
" device_type count";
756 template <
typename Op>
759 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
760 std::size_t numOperandsInSegments = 0;
766 if (maxInSegment != 0 && segCount > maxInSegment)
767 return op.
emitOpError() << keyword <<
" expects a maximum of "
768 << maxInSegment <<
" values per segment";
769 numOperandsInSegments += segCount;
771 if (numOperandsInSegments != operands.size())
773 << keyword <<
" operand count does not match count in segments";
774 if (deviceTypes.getValue().size() != (
size_t)segments.size())
776 << keyword <<
" segment count does not match device_type count";
781 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
782 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
783 "privatizations",
false)))
785 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
786 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
787 "reductions",
false)))
791 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
792 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
796 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
797 getWaitOperandsDeviceTypeAttr(),
"wait")))
801 getNumWorkersDeviceTypeAttr(),
806 getVectorLengthDeviceTypeAttr(),
811 getAsyncOperandsDeviceTypeAttr(),
815 if (
failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
818 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
824 mlir::acc::DeviceType deviceType) {
827 if (
auto pos =
findSegment(*arrayAttr, deviceType))
832 bool acc::ParallelOp::hasAsyncOnly() {
836 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
844 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
846 getAsyncOperands(), deviceType);
849 mlir::Value acc::ParallelOp::getNumWorkersValue() {
854 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
859 mlir::Value acc::ParallelOp::getVectorLengthValue() {
864 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
866 getVectorLength(), deviceType);
874 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
876 getNumGangsSegments(), deviceType);
879 bool acc::ParallelOp::hasWaitOnly() {
883 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
892 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
894 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
895 getHasWaitDevnum(), deviceType);
902 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
904 getWaitOperandsSegments(), getHasWaitDevnum(),
920 int32_t crtOperandsSize = operands.size();
923 if (parser.parseOperand(operands.emplace_back()) ||
924 parser.parseColonType(types.emplace_back()))
929 seg.push_back(operands.size() - crtOperandsSize);
953 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
955 p <<
" [" << attr <<
"]";
960 std::optional<mlir::ArrayAttr> deviceTypes,
961 std::optional<mlir::DenseI32ArrayAttr> segments) {
965 llvm::interleaveComma(
966 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
967 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
987 int32_t crtOperandsSize = operands.size();
991 if (parser.parseOperand(operands.emplace_back()) ||
992 parser.parseColonType(types.emplace_back()))
998 seg.push_back(operands.size() - crtOperandsSize);
1024 std::optional<mlir::DenseI32ArrayAttr> segments) {
1026 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1028 llvm::interleaveComma(
1029 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1030 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1043 mlir::ArrayAttr &keywordOnly) {
1047 bool needCommaBeforeOperands =
false;
1060 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1067 needCommaBeforeOperands =
true;
1077 int32_t crtOperandsSize = operands.size();
1089 if (parser.parseOperand(operands.emplace_back()) ||
1090 parser.parseColonType(types.emplace_back()))
1096 seg.push_back(operands.size() - crtOperandsSize);
1125 if (attrs->size() != 1)
1127 if (
auto deviceTypeAttr =
1128 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1135 std::optional<mlir::ArrayAttr> deviceTypes,
1136 std::optional<mlir::DenseI32ArrayAttr> segments,
1137 std::optional<mlir::ArrayAttr> hasDevNum,
1138 std::optional<mlir::ArrayAttr> keywordOnly) {
1150 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1152 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1153 if (boolAttr && boolAttr.getValue())
1155 llvm::interleaveComma(
1156 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1157 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1173 if (parser.parseOperand(operands.emplace_back()) ||
1174 parser.parseColonType(types.emplace_back()))
1176 if (succeeded(parser.parseOptionalLSquare())) {
1177 if (parser.parseAttribute(attributes.emplace_back()) ||
1178 parser.parseRSquare())
1181 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1182 parser.getContext(), mlir::acc::DeviceType::None));
1196 std::optional<mlir::ArrayAttr> deviceTypes) {
1199 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1200 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1209 mlir::ArrayAttr &keywordOnlyDeviceType) {
1212 bool needCommaBeforeOperands =
false;
1218 keywordOnlyDeviceType =
1227 if (parser.parseAttribute(
1228 keywordOnlyDeviceTypeAttributes.emplace_back()))
1235 needCommaBeforeOperands =
true;
1243 if (parser.parseOperand(operands.emplace_back()) ||
1244 parser.parseColonType(types.emplace_back()))
1246 if (succeeded(parser.parseOptionalLSquare())) {
1247 if (parser.parseAttribute(attributes.emplace_back()) ||
1248 parser.parseRSquare())
1251 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1252 parser.getContext(), mlir::acc::DeviceType::None));
1258 if (
failed(parser.parseRParen()))
1270 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1272 if (operands.begin() == operands.end() &&
1288 mlir::acc::CombinedConstructsTypeAttr &attr) {
1294 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1297 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1300 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1303 "expected compute construct name");
1314 mlir::acc::CombinedConstructsTypeAttr attr) {
1316 switch (attr.getValue()) {
1317 case mlir::acc::CombinedConstructsType::KernelsLoop:
1318 p <<
"combined(kernels)";
1320 case mlir::acc::CombinedConstructsType::ParallelLoop:
1321 p <<
"combined(parallel)";
1323 case mlir::acc::CombinedConstructsType::SerialLoop:
1324 p <<
"combined(serial)";
1334 unsigned SerialOp::getNumDataOperands() {
1335 return getReductionOperands().size() + getGangPrivateOperands().size() +
1336 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1339 Value SerialOp::getDataOperand(
unsigned i) {
1340 unsigned numOptional = getAsyncOperands().size();
1341 numOptional += getIfCond() ? 1 : 0;
1342 numOptional += getSelfCond() ? 1 : 0;
1343 return getOperand(getWaitOperands().size() + numOptional + i);
1346 bool acc::SerialOp::hasAsyncOnly() {
1350 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1358 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1360 getAsyncOperands(), deviceType);
1363 bool acc::SerialOp::hasWaitOnly() {
1367 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1376 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1378 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1379 getHasWaitDevnum(), deviceType);
1386 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1388 getWaitOperandsSegments(), getHasWaitDevnum(),
1393 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1394 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
1395 "privatizations",
false)))
1397 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1398 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1399 "reductions",
false)))
1403 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1404 getWaitOperandsDeviceTypeAttr(),
"wait")))
1408 getAsyncOperandsDeviceTypeAttr(),
1412 if (
failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1415 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1422 unsigned KernelsOp::getNumDataOperands() {
1423 return getDataClauseOperands().size();
1426 Value KernelsOp::getDataOperand(
unsigned i) {
1427 unsigned numOptional = getAsyncOperands().size();
1428 numOptional += getWaitOperands().size();
1429 numOptional += getNumGangs().size();
1430 numOptional += getNumWorkers().size();
1431 numOptional += getVectorLength().size();
1432 numOptional += getIfCond() ? 1 : 0;
1433 numOptional += getSelfCond() ? 1 : 0;
1434 return getOperand(numOptional + i);
1437 bool acc::KernelsOp::hasAsyncOnly() {
1441 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1449 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1451 getAsyncOperands(), deviceType);
1454 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1459 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1464 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1469 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1471 getVectorLength(), deviceType);
1479 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1481 getNumGangsSegments(), deviceType);
1484 bool acc::KernelsOp::hasWaitOnly() {
1488 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1497 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1499 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1500 getHasWaitDevnum(), deviceType);
1507 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1509 getWaitOperandsSegments(), getHasWaitDevnum(),
1515 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1516 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1520 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1521 getWaitOperandsDeviceTypeAttr(),
"wait")))
1525 getNumWorkersDeviceTypeAttr(),
1530 getVectorLengthDeviceTypeAttr(),
1535 getAsyncOperandsDeviceTypeAttr(),
1539 if (
failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
1542 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
1550 if (getDataClauseOperands().empty())
1551 return emitError(
"at least one operand must appear on the host_data "
1554 for (
mlir::Value operand : getDataClauseOperands())
1555 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1556 return emitError(
"expect data entry operation as defining op");
1562 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1574 bool &needCommaBetweenValues,
bool &newValue) {
1581 attributes.push_back(gangArgType);
1582 needCommaBetweenValues =
true;
1593 mlir::ArrayAttr &gangOnlyDeviceType) {
1598 bool needCommaBetweenValues =
false;
1599 bool needCommaBeforeOperands =
false;
1605 gangOnlyDeviceType =
1614 if (parser.parseAttribute(
1615 gangOnlyDeviceTypeAttributes.emplace_back()))
1622 needCommaBeforeOperands =
true;
1626 mlir::acc::GangArgType::Num);
1628 mlir::acc::GangArgType::Dim);
1630 parser.
getContext(), mlir::acc::GangArgType::Static);
1633 if (needCommaBeforeOperands) {
1634 needCommaBeforeOperands =
false;
1641 int32_t crtOperandsSize = gangOperands.size();
1643 bool newValue =
false;
1644 bool needValue =
false;
1645 if (needCommaBetweenValues) {
1653 gangOperands, gangOperandsType,
1654 gangArgTypeAttributes, argNum,
1655 needCommaBetweenValues, newValue)))
1658 gangOperands, gangOperandsType,
1659 gangArgTypeAttributes, argDim,
1660 needCommaBetweenValues, newValue)))
1663 gangOperands, gangOperandsType,
1664 gangArgTypeAttributes, argStatic,
1665 needCommaBetweenValues, newValue)))
1668 if (!newValue && needValue) {
1670 "new value expected after comma");
1678 if (gangOperands.empty())
1681 "expect at least one of num, dim or static values");
1687 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
1695 seg.push_back(gangOperands.size() - crtOperandsSize);
1703 gangArgTypeAttributes.end());
1708 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1717 std::optional<mlir::ArrayAttr> gangArgTypes,
1718 std::optional<mlir::ArrayAttr> deviceTypes,
1719 std::optional<mlir::DenseI32ArrayAttr> segments,
1720 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1722 if (operands.begin() == operands.end() &&
1737 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1739 llvm::interleaveComma(
1740 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1741 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1742 (*gangArgTypes)[opIdx]);
1743 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1744 p << LoopOp::getGangNumKeyword();
1745 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1746 p << LoopOp::getGangDimKeyword();
1747 else if (gangArgTypeAttr.getValue() ==
1748 mlir::acc::GangArgType::Static)
1749 p << LoopOp::getGangStaticKeyword();
1750 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
1761 std::optional<mlir::ArrayAttr> segments,
1762 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1765 for (
auto attr : *segments) {
1766 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1767 if (deviceTypes.contains(deviceTypeAttr.getValue()))
1769 deviceTypes.insert(deviceTypeAttr.getValue());
1776 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1779 for (
auto attr : deviceTypes) {
1780 auto deviceTypeAttr =
1781 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1782 if (!deviceTypeAttr)
1784 if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1786 crtDeviceTypes.insert(deviceTypeAttr.getValue());
1792 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1793 (getUpperbound().size() != getInclusiveUpperbound()->size()))
1794 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
1795 <<
" as upperbound size";
1798 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1799 return emitOpError() <<
"collapse device_type attr must be define when"
1800 <<
" collapse attr is present";
1802 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1803 getCollapseAttr().getValue().size() !=
1804 getCollapseDeviceTypeAttr().getValue().size())
1805 return emitOpError() <<
"collapse attribute count must match collapse"
1806 <<
" device_type count";
1808 return emitOpError()
1809 <<
"duplicate device_type found in collapseDeviceType attribute";
1812 if (!getGangOperands().empty()) {
1813 if (!getGangOperandsArgType())
1814 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
1815 <<
" when gang operands are present";
1817 if (getGangOperands().size() !=
1818 getGangOperandsArgTypeAttr().getValue().size())
1819 return emitOpError() <<
"gangOperandsArgType attribute count must match"
1820 <<
" gangOperands count";
1823 return emitOpError() <<
"duplicate device_type found in gang attribute";
1826 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
1827 getGangOperandsDeviceTypeAttr(),
"gang")))
1832 return emitOpError() <<
"duplicate device_type found in worker attribute";
1834 return emitOpError() <<
"duplicate device_type found in "
1835 "workerNumOperandsDeviceType attribute";
1837 getWorkerNumOperandsDeviceTypeAttr(),
1843 return emitOpError() <<
"duplicate device_type found in vector attribute";
1845 return emitOpError() <<
"duplicate device_type found in "
1846 "vectorOperandsDeviceType attribute";
1848 getVectorOperandsDeviceTypeAttr(),
1853 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
1854 getTileOperandsDeviceTypeAttr(),
"tile")))
1858 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1862 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
1863 <<
"\", " << getIndependentAttrName() <<
", "
1865 <<
" can be present at the same time";
1870 for (
auto attr : getSeqAttr()) {
1871 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1872 if (hasVector(deviceTypeAttr.getValue()) ||
1873 getVectorValue(deviceTypeAttr.getValue()) ||
1874 hasWorker(deviceTypeAttr.getValue()) ||
1875 getWorkerValue(deviceTypeAttr.getValue()) ||
1876 hasGang(deviceTypeAttr.getValue()) ||
1877 getGangValue(mlir::acc::GangArgType::Num,
1878 deviceTypeAttr.getValue()) ||
1879 getGangValue(mlir::acc::GangArgType::Dim,
1880 deviceTypeAttr.getValue()) ||
1881 getGangValue(mlir::acc::GangArgType::Static,
1882 deviceTypeAttr.getValue()))
1884 <<
"gang, worker or vector cannot appear with the seq attr";
1888 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1889 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1890 "privatizations",
false)))
1893 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1894 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1895 "reductions",
false)))
1898 if (getCombined().has_value() &&
1899 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1900 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1901 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1902 return emitError(
"unexpected combined constructs attribute");
1906 if (getRegion().empty())
1907 return emitError(
"expected non-empty body.");
1912 unsigned LoopOp::getNumDataOperands() {
1913 return getReductionOperands().size() + getPrivateOperands().size();
1916 Value LoopOp::getDataOperand(
unsigned i) {
1917 unsigned numOptional =
1918 getLowerbound().size() + getUpperbound().size() + getStep().size();
1919 numOptional += getGangOperands().size();
1920 numOptional += getVectorOperands().size();
1921 numOptional += getWorkerNumOperands().size();
1922 numOptional += getTileOperands().size();
1923 numOptional += getCacheOperands().size();
1924 return getOperand(numOptional + i);
1929 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1933 bool LoopOp::hasIndependent() {
1937 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1943 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1951 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1953 getVectorOperands(), deviceType);
1958 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1966 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1968 getWorkerNumOperands(), deviceType);
1973 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
1982 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
1984 getTileOperandsSegments(), deviceType);
1987 std::optional<int64_t> LoopOp::getCollapseValue() {
1991 std::optional<int64_t>
1992 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
1993 if (!getCollapseAttr())
1994 return std::nullopt;
1995 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
1997 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
1998 return intAttr.getValue().getZExtValue();
2000 return std::nullopt;
2003 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2007 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2008 mlir::acc::DeviceType deviceType) {
2009 if (getGangOperands().empty())
2011 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2012 int32_t nbOperandsBefore = 0;
2013 for (
unsigned i = 0; i < *pos; ++i)
2014 nbOperandsBefore += (*getGangOperandsSegments())[i];
2017 .drop_front(nbOperandsBefore)
2018 .take_front((*getGangOperandsSegments())[*pos]);
2020 int32_t argTypeIdx = nbOperandsBefore;
2021 for (
auto value : values) {
2022 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2023 (*getGangOperandsArgType())[argTypeIdx]);
2024 if (gangArgTypeAttr.getValue() == gangArgType)
2034 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2039 return {&getRegion()};
2083 if (!regionArgs.empty()) {
2084 p << acc::LoopOp::getControlKeyword() <<
"(";
2085 llvm::interleaveComma(regionArgs, p,
2087 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2088 << upperbound <<
" : " << upperboundType <<
") "
2089 <<
" step (" << steps <<
" : " << stepType <<
") ";
2102 if (getOperands().empty() && !getDefaultAttr())
2103 return emitError(
"at least one operand or the default attribute "
2104 "must appear on the data operation");
2106 for (
mlir::Value operand : getDataClauseOperands())
2107 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2108 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2109 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2110 operand.getDefiningOp()))
2111 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2114 if (
failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2120 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2122 Value DataOp::getDataOperand(
unsigned i) {
2123 unsigned numOptional = getIfCond() ? 1 : 0;
2124 numOptional += getAsyncOperands().size() ? 1 : 0;
2125 numOptional += getWaitOperands().size();
2126 return getOperand(numOptional + i);
2129 bool acc::DataOp::hasAsyncOnly() {
2133 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2141 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2143 getAsyncOperands(), deviceType);
2148 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2157 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2159 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2160 getHasWaitDevnum(), deviceType);
2167 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2169 getWaitOperandsSegments(), getHasWaitDevnum(),
2181 if (getDataClauseOperands().empty())
2182 return emitError(
"at least one operand must be present in dataOperands on "
2183 "the exit data operation");
2187 if (getAsyncOperand() && getAsync())
2188 return emitError(
"async attribute cannot appear with asyncOperand");
2192 if (!getWaitOperands().empty() && getWait())
2193 return emitError(
"wait attribute cannot appear with waitOperands");
2195 if (getWaitDevnum() && getWaitOperands().empty())
2196 return emitError(
"wait_devnum cannot appear without waitOperands");
2201 unsigned ExitDataOp::getNumDataOperands() {
2202 return getDataClauseOperands().size();
2205 Value ExitDataOp::getDataOperand(
unsigned i) {
2206 unsigned numOptional = getIfCond() ? 1 : 0;
2207 numOptional += getAsyncOperand() ? 1 : 0;
2208 numOptional += getWaitDevnum() ? 1 : 0;
2209 return getOperand(getWaitOperands().size() + numOptional + i);
2214 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
2225 if (getDataClauseOperands().empty())
2226 return emitError(
"at least one operand must be present in dataOperands on "
2227 "the enter data operation");
2231 if (getAsyncOperand() && getAsync())
2232 return emitError(
"async attribute cannot appear with asyncOperand");
2236 if (!getWaitOperands().empty() && getWait())
2237 return emitError(
"wait attribute cannot appear with waitOperands");
2239 if (getWaitDevnum() && getWaitOperands().empty())
2240 return emitError(
"wait_devnum cannot appear without waitOperands");
2242 for (
mlir::Value operand : getDataClauseOperands())
2243 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2244 operand.getDefiningOp()))
2245 return emitError(
"expect data entry operation as defining op");
2250 unsigned EnterDataOp::getNumDataOperands() {
2251 return getDataClauseOperands().size();
2254 Value EnterDataOp::getDataOperand(
unsigned i) {
2255 unsigned numOptional = getIfCond() ? 1 : 0;
2256 numOptional += getAsyncOperand() ? 1 : 0;
2257 numOptional += getWaitDevnum() ? 1 : 0;
2258 return getOperand(getWaitOperands().size() + numOptional + i);
2263 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
2282 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2289 if (
Value writeVal = op.getWriteOpVal()) {
2299 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2305 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2306 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2308 return dyn_cast<AtomicReadOp>(getSecondOp());
2311 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2312 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2314 return dyn_cast<AtomicWriteOp>(getSecondOp());
2317 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2318 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2320 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2323 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
2329 template <
typename Op>
2332 bool requireAtLeastOneOperand =
true) {
2333 if (operands.empty() && requireAtLeastOneOperand)
2336 "at least one operand must appear on the declare operation");
2339 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2340 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2341 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2342 operand.getDefiningOp()))
2344 "expect valid declare data entry operation or acc.getdeviceptr "
2348 assert(varPtr &&
"declare operands can only be data entry operations which "
2349 "must have varPtr");
2350 std::optional<mlir::acc::DataClause> dataClauseOptional{
2352 assert(dataClauseOptional.has_value() &&
2353 "declare operands can only be data entry operations which must have "
2357 if (!varPtr.getDefiningOp())
2361 auto declareAttribute{
2363 if (!declareAttribute)
2365 "expect declare attribute on variable in declare operation");
2367 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2368 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2370 "expect matching declare attribute on variable in declare operation");
2377 if (declAttr.getImplicit() &&
2380 "implicitness must match between declare op and flag on variable");
2414 acc::DeviceType dtype) {
2415 unsigned parallelism = 0;
2416 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2417 parallelism += op.hasWorker(dtype) ? 1 : 0;
2418 parallelism += op.hasVector(dtype) ? 1 : 0;
2419 parallelism += op.hasSeq(dtype) ? 1 : 0;
2424 unsigned baseParallelism =
2427 if (baseParallelism > 1)
2428 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2429 "be present at the same time";
2431 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2433 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
2438 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2439 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2440 "be present at the same time";
2447 mlir::ArrayAttr &deviceTypes) {
2452 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2454 if (failed(parser.parseOptionalLSquare())) {
2455 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2456 parser.getContext(), mlir::acc::DeviceType::None));
2458 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2459 parser.parseRSquare())
2467 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2473 std::optional<mlir::ArrayAttr> bindName,
2474 std::optional<mlir::ArrayAttr> deviceTypes) {
2475 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2476 [&](
const auto &pair) {
2477 p << std::get<0>(pair);
2483 mlir::ArrayAttr &gang,
2484 mlir::ArrayAttr &gangDim,
2485 mlir::ArrayAttr &gangDimDeviceTypes) {
2488 gangDimDeviceTypeAttrs;
2489 bool needCommaBeforeOperands =
false;
2502 if (parser.parseAttribute(gangAttrs.emplace_back()))
2509 needCommaBeforeOperands =
true;
2516 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2517 parser.parseColon() ||
2518 parser.parseAttribute(gangDimAttrs.emplace_back()))
2520 if (succeeded(parser.parseOptionalLSquare())) {
2521 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2522 parser.parseRSquare())
2525 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2526 parser.getContext(), mlir::acc::DeviceType::None));
2532 if (
failed(parser.parseRParen()))
2537 gangDimDeviceTypes =
2544 std::optional<mlir::ArrayAttr> gang,
2545 std::optional<mlir::ArrayAttr> gangDim,
2546 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2549 gang->size() == 1) {
2550 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2563 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2564 [&](
const auto &pair) {
2565 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
2566 p << std::get<0>(pair);
2574 mlir::ArrayAttr &deviceTypes) {
2587 if (parser.parseAttribute(attributes.emplace_back()))
2601 std::optional<mlir::ArrayAttr> deviceTypes) {
2604 auto deviceTypeAttr =
2605 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2615 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2623 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2629 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2635 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2639 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2643 std::optional<llvm::StringRef>
2644 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2646 return std::nullopt;
2647 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
2648 auto attr = (*getBindName())[*pos];
2649 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2650 return stringAttr.getValue();
2652 return std::nullopt;
2657 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2661 std::optional<int64_t> RoutineOp::getGangDimValue() {
2665 std::optional<int64_t>
2666 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2668 return std::nullopt;
2669 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
2670 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2671 return intAttr.getInt();
2673 return std::nullopt;
2684 return emitOpError(
"cannot be nested in a compute operation");
2696 return emitOpError(
"cannot be nested in a compute operation");
2708 return emitOpError(
"cannot be nested in a compute operation");
2709 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2710 return emitOpError(
"at least one default_async, device_num, or device_type "
2711 "operand must appear");
2721 if (getDataClauseOperands().empty())
2722 return emitError(
"at least one value must be present in dataOperands");
2725 getAsyncOperandsDeviceTypeAttr(),
2730 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2731 getWaitOperandsDeviceTypeAttr(),
"wait")))
2734 if (
failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
2737 for (
mlir::Value operand : getDataClauseOperands())
2738 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2739 operand.getDefiningOp()))
2740 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2746 unsigned UpdateOp::getNumDataOperands() {
2747 return getDataClauseOperands().size();
2750 Value UpdateOp::getDataOperand(
unsigned i) {
2751 unsigned numOptional = getAsyncOperands().size();
2752 numOptional += getIfCond() ? 1 : 0;
2753 return getOperand(getWaitOperands().size() + numOptional + i);
2758 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
2761 bool UpdateOp::hasAsyncOnly() {
2765 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2773 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2777 if (
auto pos =
findSegment(*getAsyncOperandsDeviceType(), deviceType))
2778 return getAsyncOperands()[*pos];
2783 bool UpdateOp::hasWaitOnly() {
2787 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2796 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2798 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2799 getHasWaitDevnum(), deviceType);
2806 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2808 getWaitOperandsSegments(), getHasWaitDevnum(),
2819 if (getAsyncOperand() && getAsync())
2820 return emitError(
"async attribute cannot appear with asyncOperand");
2822 if (getWaitDevnum() && getWaitOperands().empty())
2823 return emitError(
"wait_devnum cannot appear without waitOperands");
2828 #define GET_OP_CLASSES
2829 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2831 #define GET_ATTRDEF_CLASSES
2832 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2834 #define GET_TYPEDEF_CLASSES
2835 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2844 [&](
auto entry) {
return entry.getVarPtr(); })
2845 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2846 [&](
auto exit) {
return exit.getVarPtr(); })
2854 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
2863 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
2873 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
2875 dataClause.getBounds().begin(), dataClause.getBounds().end());
2886 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
2893 std::optional<mlir::acc::DataClause>
2898 .Case<ACC_DATA_ENTRY_OPS>(
2899 [&](
auto entry) {
return entry.getDataClause(); })
2907 [&](
auto entry) {
return entry.getImplicit(); })
2916 [&](
auto entry) {
return entry.getDataClauseOperands(); })
2918 return dataOperands;
2926 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
2928 return dataOperands;
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
This class represents success/failure for parsing-like operations that find it important to chain tog...
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
LogicalResult failure(bool isFailure=true)
Utility function to generate a LogicalResult.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool succeeded(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a success value.
LogicalResult success(bool isSuccess=true)
Utility function to generate a LogicalResult.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
bool failed(LogicalResult result)
Utility function that returns true if the provided LogicalResult corresponds to a failure value.
This class represents an efficient way to signal success or failure.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.