19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
29 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
32 struct MemRefPointerLikeModel
33 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
36 return llvm::cast<MemRefType>(pointer).getElementType();
40 struct LLVMPointerPointerLikeModel
41 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
42 LLVM::LLVMPointerType> {
51 void OpenACCDialect::initialize() {
54 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
57 #define GET_ATTRDEF_LIST
58 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
61 #define GET_TYPEDEF_LIST
62 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
68 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
69 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
78 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
84 mlir::acc::DeviceType deviceType) {
88 for (
auto attr : *arrayAttr) {
89 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
90 if (deviceTypeAttr.getValue() == deviceType)
98 std::optional<mlir::ArrayAttr> deviceTypes) {
103 llvm::interleaveComma(*deviceTypes, p,
109 mlir::acc::DeviceType deviceType) {
110 unsigned segmentIdx = 0;
111 for (
auto attr : segments) {
112 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
113 if (deviceTypeAttr.getValue() == deviceType)
114 return std::make_optional(segmentIdx);
124 mlir::acc::DeviceType deviceType) {
126 return range.take_front(0);
127 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
128 int32_t nbOperandsBefore = 0;
129 for (
unsigned i = 0; i < *pos; ++i)
130 nbOperandsBefore += (*segments)[i];
131 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
133 return range.take_front(0);
140 std::optional<mlir::ArrayAttr> hasWaitDevnum,
141 mlir::acc::DeviceType deviceType) {
144 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
145 if (hasWaitDevnum->getValue()[*pos])
156 std::optional<mlir::ArrayAttr> hasWaitDevnum,
157 mlir::acc::DeviceType deviceType) {
162 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
163 if (hasWaitDevnum && *hasWaitDevnum) {
164 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
165 if (boolAttr.getValue())
166 return range.drop_front(1);
172 template <
typename Op>
174 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
176 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
181 op.hasAsyncOnly(dtype))
182 return op.
emitError(
"async attribute cannot appear with asyncOperand");
187 op.hasWaitOnly(dtype))
188 return op.
emitError(
"wait attribute cannot appear with waitOperands");
197 auto extent = getExtent();
198 auto upperbound = getUpperbound();
199 if (!extent && !upperbound)
200 return emitError(
"expected extent or upperbound.");
210 "data clause associated with private operation must match its intent");
219 return emitError(
"data clause associated with firstprivate operation must "
229 return emitError(
"data clause associated with reduction operation must "
239 return emitError(
"data clause associated with deviceptr operation must "
250 "data clause associated with present operation must match its intent");
259 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
264 "data clause associated with copyin operation must match its intent"
265 " or specify original clause this operation was decomposed from");
269 bool acc::CopyinOp::isCopyinReadonly() {
270 return getDataClause() == acc::DataClause::acc_copyin_readonly;
283 "data clause associated with create operation must match its intent"
284 " or specify original clause this operation was decomposed from");
288 bool acc::CreateOp::isCreateZero() {
290 return getDataClause() == acc::DataClause::acc_create_zero ||
299 return emitError(
"data clause associated with no_create operation must "
310 "data clause associated with attach operation must match its intent");
319 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
320 return emitError(
"data clause associated with device_resident operation "
321 "must match its intent");
332 "data clause associated with link operation must match its intent");
346 "data clause associated with copyout operation must match its intent"
347 " or specify original clause this operation was decomposed from");
349 return emitError(
"must have both host and device pointers");
353 bool acc::CopyoutOp::isCopyoutZero() {
368 getDataClause() != acc::DataClause::acc_declare_device_resident &&
371 "data clause associated with delete operation must match its intent"
372 " or specify original clause this operation was decomposed from");
374 return emitError(
"must have device pointer");
386 "data clause associated with detach operation must match its intent"
387 " or specify original clause this operation was decomposed from");
389 return emitError(
"must have device pointer");
401 "data clause associated with host operation must match its intent"
402 " or specify original clause this operation was decomposed from");
404 return emitError(
"must have both host and device pointers");
415 "data clause associated with device operation must match its intent"
416 " or specify original clause this operation was decomposed from");
427 "data clause associated with use_device operation must match its intent"
428 " or specify original clause this operation was decomposed from");
440 "data clause associated with cache operation must match its intent"
441 " or specify original clause this operation was decomposed from");
445 template <
typename StructureOp>
447 unsigned nRegions = 1) {
450 for (
unsigned i = 0; i < nRegions; ++i)
451 regions.push_back(state.addRegion());
453 for (
Region *region : regions)
461 return isa<acc::ParallelOp, acc::LoopOp>(op);
468 template <
typename OpTy>
472 LogicalResult matchAndRewrite(OpTy op,
475 Value ifCond = op.getIfCond();
479 IntegerAttr constAttr;
482 if (constAttr.getInt())
483 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
495 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
507 template <
typename OpTy>
508 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
511 LogicalResult matchAndRewrite(OpTy op,
514 Value ifCond = op.getIfCond();
518 IntegerAttr constAttr;
521 if (constAttr.getInt())
522 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
537 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
539 if (optional && region.
empty())
543 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
547 return op->
emitOpError() <<
"expects " << regionName
550 << regionType <<
" type";
553 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
554 if (yieldOp.getOperands().size() != 1 ||
555 yieldOp.getOperands().getTypes()[0] != type)
556 return op->
emitOpError() <<
"expects " << regionName
558 "yield a value of the "
559 << regionType <<
" type";
565 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
567 "privatization",
"init",
getType(),
571 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
581 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
583 "privatization",
"init",
getType(),
587 if (getCopyRegion().empty())
588 return emitOpError() <<
"expects non-empty copy region";
593 return emitOpError() <<
"expects copy region with two arguments of the "
594 "privatization type";
596 if (getDestroyRegion().empty())
600 "privatization",
"destroy",
611 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
617 if (getCombinerRegion().empty())
618 return emitOpError() <<
"expects non-empty combiner region";
620 Block &reductionBlock = getCombinerRegion().
front();
624 return emitOpError() <<
"expects combiner region with the first two "
625 <<
"arguments of the reduction type";
627 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
628 if (yieldOp.getOperands().size() != 1 ||
629 yieldOp.getOperands().getTypes()[0] !=
getType())
630 return emitOpError() <<
"expects combiner region to yield a value "
631 "of the reduction type";
647 if (parser.parseAttribute(attributes.emplace_back()) ||
648 parser.parseArrow() ||
649 parser.parseOperand(operands.emplace_back()) ||
650 parser.parseColonType(types.emplace_back()))
664 std::optional<mlir::ArrayAttr> attributes) {
665 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
666 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
667 << std::get<1>(it).getType();
676 template <
typename Op>
680 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
681 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
682 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
683 operand.getDefiningOp()))
685 "expect data entry/exit operation or acc.getdeviceptr "
690 template <
typename Op>
694 llvm::StringRef symbolName,
bool checkOperandType =
true) {
695 if (!operands.empty()) {
696 if (!attributes || attributes->size() != operands.size())
698 <<
"expected as many " << symbolName <<
" symbol reference as "
699 << operandName <<
" operands";
703 <<
"unexpected " << symbolName <<
" symbol reference";
708 for (
auto args : llvm::zip(operands, *attributes)) {
711 if (!set.insert(operand).second)
713 << operandName <<
" operand appears more than once";
716 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
717 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
720 <<
"expected symbol reference " << symbolRef <<
" to point to a "
721 << operandName <<
" declaration";
723 if (checkOperandType && decl.getType() && decl.getType() != varType)
724 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
725 <<
") to be the same type as " << operandName
726 <<
" declaration (" << decl.getType() <<
")";
732 unsigned ParallelOp::getNumDataOperands() {
733 return getReductionOperands().size() + getGangPrivateOperands().size() +
734 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
737 Value ParallelOp::getDataOperand(
unsigned i) {
739 numOptional += getNumGangs().size();
740 numOptional += getNumWorkers().size();
741 numOptional += getVectorLength().size();
742 numOptional += getIfCond() ? 1 : 0;
743 numOptional += getSelfCond() ? 1 : 0;
744 return getOperand(getWaitOperands().size() + numOptional + i);
747 template <
typename Op>
749 ArrayAttr deviceTypes,
750 llvm::StringRef keyword) {
751 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
752 return op.
emitOpError() << keyword <<
" operands count must match "
753 << keyword <<
" device_type count";
757 template <
typename Op>
760 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
761 std::size_t numOperandsInSegments = 0;
767 if (maxInSegment != 0 && segCount > maxInSegment)
768 return op.
emitOpError() << keyword <<
" expects a maximum of "
769 << maxInSegment <<
" values per segment";
770 numOperandsInSegments += segCount;
772 if (numOperandsInSegments != operands.size())
774 << keyword <<
" operand count does not match count in segments";
775 if (deviceTypes.getValue().size() != (
size_t)segments.size())
777 << keyword <<
" segment count does not match device_type count";
782 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
783 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
784 "privatizations",
false)))
786 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
787 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
788 "reductions",
false)))
792 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
793 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
797 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
798 getWaitOperandsDeviceTypeAttr(),
"wait")))
802 getNumWorkersDeviceTypeAttr(),
807 getVectorLengthDeviceTypeAttr(),
812 getAsyncOperandsDeviceTypeAttr(),
816 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
819 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
825 mlir::acc::DeviceType deviceType) {
828 if (
auto pos =
findSegment(*arrayAttr, deviceType))
833 bool acc::ParallelOp::hasAsyncOnly() {
837 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
845 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
850 mlir::Value acc::ParallelOp::getNumWorkersValue() {
855 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
860 mlir::Value acc::ParallelOp::getVectorLengthValue() {
865 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
867 getVectorLength(), deviceType);
875 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
877 getNumGangsSegments(), deviceType);
880 bool acc::ParallelOp::hasWaitOnly() {
884 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
893 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
895 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
896 getHasWaitDevnum(), deviceType);
903 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
905 getWaitOperandsSegments(), getHasWaitDevnum(),
921 odsBuilder, odsState, asyncOperands,
nullptr,
922 nullptr, waitOperands,
nullptr,
924 nullptr, numGangs,
nullptr,
926 nullptr, vectorLength,
927 nullptr, ifCond, selfCond,
928 nullptr, reductionOperands,
nullptr,
929 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
930 nullptr, dataClauseOperands,
946 int32_t crtOperandsSize = operands.size();
949 if (parser.parseOperand(operands.emplace_back()) ||
950 parser.parseColonType(types.emplace_back()))
955 seg.push_back(operands.size() - crtOperandsSize);
979 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
981 p <<
" [" << attr <<
"]";
986 std::optional<mlir::ArrayAttr> deviceTypes,
987 std::optional<mlir::DenseI32ArrayAttr> segments) {
991 llvm::interleaveComma(
992 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
993 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1013 int32_t crtOperandsSize = operands.size();
1017 if (parser.parseOperand(operands.emplace_back()) ||
1018 parser.parseColonType(types.emplace_back()))
1024 seg.push_back(operands.size() - crtOperandsSize);
1050 std::optional<mlir::DenseI32ArrayAttr> segments) {
1052 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1054 llvm::interleaveComma(
1055 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1056 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1069 mlir::ArrayAttr &keywordOnly) {
1073 bool needCommaBeforeOperands =
false;
1086 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1093 needCommaBeforeOperands =
true;
1096 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1103 int32_t crtOperandsSize = operands.size();
1115 if (parser.parseOperand(operands.emplace_back()) ||
1116 parser.parseColonType(types.emplace_back()))
1122 seg.push_back(operands.size() - crtOperandsSize);
1151 if (attrs->size() != 1)
1153 if (
auto deviceTypeAttr =
1154 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1161 std::optional<mlir::ArrayAttr> deviceTypes,
1162 std::optional<mlir::DenseI32ArrayAttr> segments,
1163 std::optional<mlir::ArrayAttr> hasDevNum,
1164 std::optional<mlir::ArrayAttr> keywordOnly) {
1176 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1178 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1179 if (boolAttr && boolAttr.getValue())
1181 llvm::interleaveComma(
1182 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1183 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1199 if (parser.parseOperand(operands.emplace_back()) ||
1200 parser.parseColonType(types.emplace_back()))
1202 if (succeeded(parser.parseOptionalLSquare())) {
1203 if (parser.parseAttribute(attributes.emplace_back()) ||
1204 parser.parseRSquare())
1207 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1208 parser.getContext(), mlir::acc::DeviceType::None));
1222 std::optional<mlir::ArrayAttr> deviceTypes) {
1225 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1226 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1235 mlir::ArrayAttr &keywordOnlyDeviceType) {
1238 bool needCommaBeforeOperands =
false;
1244 keywordOnlyDeviceType =
1253 if (parser.parseAttribute(
1254 keywordOnlyDeviceTypeAttributes.emplace_back()))
1261 needCommaBeforeOperands =
true;
1264 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1269 if (parser.parseOperand(operands.emplace_back()) ||
1270 parser.parseColonType(types.emplace_back()))
1272 if (succeeded(parser.parseOptionalLSquare())) {
1273 if (parser.parseAttribute(attributes.emplace_back()) ||
1274 parser.parseRSquare())
1277 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1278 parser.getContext(), mlir::acc::DeviceType::None));
1284 if (failed(parser.parseRParen()))
1296 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1298 if (operands.begin() == operands.end() &&
1314 mlir::acc::CombinedConstructsTypeAttr &attr) {
1320 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1323 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1326 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1329 "expected compute construct name");
1340 mlir::acc::CombinedConstructsTypeAttr attr) {
1342 switch (attr.getValue()) {
1343 case mlir::acc::CombinedConstructsType::KernelsLoop:
1344 p <<
"combined(kernels)";
1346 case mlir::acc::CombinedConstructsType::ParallelLoop:
1347 p <<
"combined(parallel)";
1349 case mlir::acc::CombinedConstructsType::SerialLoop:
1350 p <<
"combined(serial)";
1360 unsigned SerialOp::getNumDataOperands() {
1361 return getReductionOperands().size() + getGangPrivateOperands().size() +
1362 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1365 Value SerialOp::getDataOperand(
unsigned i) {
1367 numOptional += getIfCond() ? 1 : 0;
1368 numOptional += getSelfCond() ? 1 : 0;
1369 return getOperand(getWaitOperands().size() + numOptional + i);
1372 bool acc::SerialOp::hasAsyncOnly() {
1376 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1384 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1389 bool acc::SerialOp::hasWaitOnly() {
1393 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1402 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1404 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1405 getHasWaitDevnum(), deviceType);
1412 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1414 getWaitOperandsSegments(), getHasWaitDevnum(),
1419 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1420 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
1421 "privatizations",
false)))
1423 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1424 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1425 "reductions",
false)))
1429 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1430 getWaitOperandsDeviceTypeAttr(),
"wait")))
1434 getAsyncOperandsDeviceTypeAttr(),
1438 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1441 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1448 unsigned KernelsOp::getNumDataOperands() {
1449 return getDataClauseOperands().size();
1452 Value KernelsOp::getDataOperand(
unsigned i) {
1454 numOptional += getWaitOperands().size();
1455 numOptional += getNumGangs().size();
1456 numOptional += getNumWorkers().size();
1457 numOptional += getVectorLength().size();
1458 numOptional += getIfCond() ? 1 : 0;
1459 numOptional += getSelfCond() ? 1 : 0;
1460 return getOperand(numOptional + i);
1463 bool acc::KernelsOp::hasAsyncOnly() {
1467 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1475 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1480 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1485 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1490 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1495 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1497 getVectorLength(), deviceType);
1505 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1507 getNumGangsSegments(), deviceType);
1510 bool acc::KernelsOp::hasWaitOnly() {
1514 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1523 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1525 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1526 getHasWaitDevnum(), deviceType);
1533 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1535 getWaitOperandsSegments(), getHasWaitDevnum(),
1541 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1542 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1546 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1547 getWaitOperandsDeviceTypeAttr(),
"wait")))
1551 getNumWorkersDeviceTypeAttr(),
1556 getVectorLengthDeviceTypeAttr(),
1561 getAsyncOperandsDeviceTypeAttr(),
1565 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
1568 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
1576 if (getDataClauseOperands().empty())
1577 return emitError(
"at least one operand must appear on the host_data "
1580 for (
mlir::Value operand : getDataClauseOperands())
1581 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1582 return emitError(
"expect data entry operation as defining op");
1588 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1600 bool &needCommaBetweenValues,
bool &newValue) {
1607 attributes.push_back(gangArgType);
1608 needCommaBetweenValues =
true;
1619 mlir::ArrayAttr &gangOnlyDeviceType) {
1624 bool needCommaBetweenValues =
false;
1625 bool needCommaBeforeOperands =
false;
1631 gangOnlyDeviceType =
1640 if (parser.parseAttribute(
1641 gangOnlyDeviceTypeAttributes.emplace_back()))
1648 needCommaBeforeOperands =
true;
1652 mlir::acc::GangArgType::Num);
1654 mlir::acc::GangArgType::Dim);
1656 parser.
getContext(), mlir::acc::GangArgType::Static);
1659 if (needCommaBeforeOperands) {
1660 needCommaBeforeOperands =
false;
1667 int32_t crtOperandsSize = gangOperands.size();
1669 bool newValue =
false;
1670 bool needValue =
false;
1671 if (needCommaBetweenValues) {
1679 gangOperands, gangOperandsType,
1680 gangArgTypeAttributes, argNum,
1681 needCommaBetweenValues, newValue)))
1684 gangOperands, gangOperandsType,
1685 gangArgTypeAttributes, argDim,
1686 needCommaBetweenValues, newValue)))
1688 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1689 gangOperands, gangOperandsType,
1690 gangArgTypeAttributes, argStatic,
1691 needCommaBetweenValues, newValue)))
1694 if (!newValue && needValue) {
1696 "new value expected after comma");
1704 if (gangOperands.empty())
1707 "expect at least one of num, dim or static values");
1713 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
1721 seg.push_back(gangOperands.size() - crtOperandsSize);
1729 gangArgTypeAttributes.end());
1734 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1743 std::optional<mlir::ArrayAttr> gangArgTypes,
1744 std::optional<mlir::ArrayAttr> deviceTypes,
1745 std::optional<mlir::DenseI32ArrayAttr> segments,
1746 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1748 if (operands.begin() == operands.end() &&
1763 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1765 llvm::interleaveComma(
1766 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1767 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1768 (*gangArgTypes)[opIdx]);
1769 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1770 p << LoopOp::getGangNumKeyword();
1771 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1772 p << LoopOp::getGangDimKeyword();
1773 else if (gangArgTypeAttr.getValue() ==
1774 mlir::acc::GangArgType::Static)
1775 p << LoopOp::getGangStaticKeyword();
1776 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
1787 std::optional<mlir::ArrayAttr> segments,
1788 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1791 for (
auto attr : *segments) {
1792 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1793 if (deviceTypes.contains(deviceTypeAttr.getValue()))
1795 deviceTypes.insert(deviceTypeAttr.getValue());
1802 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1805 for (
auto attr : deviceTypes) {
1806 auto deviceTypeAttr =
1807 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1808 if (!deviceTypeAttr)
1810 if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1812 crtDeviceTypes.insert(deviceTypeAttr.getValue());
1818 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1819 (getUpperbound().size() != getInclusiveUpperbound()->size()))
1820 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
1821 <<
" as upperbound size";
1824 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1825 return emitOpError() <<
"collapse device_type attr must be define when"
1826 <<
" collapse attr is present";
1828 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1829 getCollapseAttr().getValue().size() !=
1830 getCollapseDeviceTypeAttr().getValue().size())
1831 return emitOpError() <<
"collapse attribute count must match collapse"
1832 <<
" device_type count";
1834 return emitOpError()
1835 <<
"duplicate device_type found in collapseDeviceType attribute";
1838 if (!getGangOperands().empty()) {
1839 if (!getGangOperandsArgType())
1840 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
1841 <<
" when gang operands are present";
1843 if (getGangOperands().size() !=
1844 getGangOperandsArgTypeAttr().getValue().size())
1845 return emitOpError() <<
"gangOperandsArgType attribute count must match"
1846 <<
" gangOperands count";
1849 return emitOpError() <<
"duplicate device_type found in gang attribute";
1852 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
1853 getGangOperandsDeviceTypeAttr(),
"gang")))
1858 return emitOpError() <<
"duplicate device_type found in worker attribute";
1860 return emitOpError() <<
"duplicate device_type found in "
1861 "workerNumOperandsDeviceType attribute";
1863 getWorkerNumOperandsDeviceTypeAttr(),
1869 return emitOpError() <<
"duplicate device_type found in vector attribute";
1871 return emitOpError() <<
"duplicate device_type found in "
1872 "vectorOperandsDeviceType attribute";
1874 getVectorOperandsDeviceTypeAttr(),
1879 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
1880 getTileOperandsDeviceTypeAttr(),
"tile")))
1884 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1888 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
1889 <<
"\", " << getIndependentAttrName() <<
", "
1891 <<
" can be present at the same time";
1896 for (
auto attr : getSeqAttr()) {
1897 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1898 if (hasVector(deviceTypeAttr.getValue()) ||
1899 getVectorValue(deviceTypeAttr.getValue()) ||
1900 hasWorker(deviceTypeAttr.getValue()) ||
1901 getWorkerValue(deviceTypeAttr.getValue()) ||
1902 hasGang(deviceTypeAttr.getValue()) ||
1903 getGangValue(mlir::acc::GangArgType::Num,
1904 deviceTypeAttr.getValue()) ||
1905 getGangValue(mlir::acc::GangArgType::Dim,
1906 deviceTypeAttr.getValue()) ||
1907 getGangValue(mlir::acc::GangArgType::Static,
1908 deviceTypeAttr.getValue()))
1910 <<
"gang, worker or vector cannot appear with the seq attr";
1914 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1915 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1916 "privatizations",
false)))
1919 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1920 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1921 "reductions",
false)))
1924 if (getCombined().has_value() &&
1925 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1926 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1927 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1928 return emitError(
"unexpected combined constructs attribute");
1932 if (getRegion().empty())
1933 return emitError(
"expected non-empty body.");
1938 unsigned LoopOp::getNumDataOperands() {
1939 return getReductionOperands().size() + getPrivateOperands().size();
1942 Value LoopOp::getDataOperand(
unsigned i) {
1943 unsigned numOptional =
1944 getLowerbound().size() + getUpperbound().size() + getStep().size();
1945 numOptional += getGangOperands().size();
1946 numOptional += getVectorOperands().size();
1947 numOptional += getWorkerNumOperands().size();
1948 numOptional += getTileOperands().size();
1949 numOptional += getCacheOperands().size();
1950 return getOperand(numOptional + i);
1955 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1959 bool LoopOp::hasIndependent() {
1963 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1969 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1977 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1979 getVectorOperands(), deviceType);
1984 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1992 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1994 getWorkerNumOperands(), deviceType);
1999 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2008 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2010 getTileOperandsSegments(), deviceType);
2013 std::optional<int64_t> LoopOp::getCollapseValue() {
2017 std::optional<int64_t>
2018 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2019 if (!getCollapseAttr())
2020 return std::nullopt;
2021 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2023 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2024 return intAttr.getValue().getZExtValue();
2026 return std::nullopt;
2029 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2033 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2034 mlir::acc::DeviceType deviceType) {
2035 if (getGangOperands().empty())
2037 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2038 int32_t nbOperandsBefore = 0;
2039 for (
unsigned i = 0; i < *pos; ++i)
2040 nbOperandsBefore += (*getGangOperandsSegments())[i];
2043 .drop_front(nbOperandsBefore)
2044 .take_front((*getGangOperandsSegments())[*pos]);
2046 int32_t argTypeIdx = nbOperandsBefore;
2047 for (
auto value : values) {
2048 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2049 (*getGangOperandsArgType())[argTypeIdx]);
2050 if (gangArgTypeAttr.getValue() == gangArgType)
2060 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2065 return {&getRegion()};
2109 if (!regionArgs.empty()) {
2110 p << acc::LoopOp::getControlKeyword() <<
"(";
2111 llvm::interleaveComma(regionArgs, p,
2113 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2114 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2115 <<
" : " << stepType <<
") ";
2128 if (getOperands().empty() && !getDefaultAttr())
2129 return emitError(
"at least one operand or the default attribute "
2130 "must appear on the data operation");
2132 for (
mlir::Value operand : getDataClauseOperands())
2133 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2134 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2135 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2136 operand.getDefiningOp()))
2137 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2140 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2146 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2148 Value DataOp::getDataOperand(
unsigned i) {
2149 unsigned numOptional = getIfCond() ? 1 : 0;
2151 numOptional += getWaitOperands().size();
2152 return getOperand(numOptional + i);
2155 bool acc::DataOp::hasAsyncOnly() {
2159 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2167 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2174 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2183 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2185 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2186 getHasWaitDevnum(), deviceType);
2193 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2195 getWaitOperandsSegments(), getHasWaitDevnum(),
2207 if (getDataClauseOperands().empty())
2208 return emitError(
"at least one operand must be present in dataOperands on "
2209 "the exit data operation");
2213 if (getAsyncOperand() && getAsync())
2214 return emitError(
"async attribute cannot appear with asyncOperand");
2218 if (!getWaitOperands().empty() && getWait())
2219 return emitError(
"wait attribute cannot appear with waitOperands");
2221 if (getWaitDevnum() && getWaitOperands().empty())
2222 return emitError(
"wait_devnum cannot appear without waitOperands");
2227 unsigned ExitDataOp::getNumDataOperands() {
2228 return getDataClauseOperands().size();
2231 Value ExitDataOp::getDataOperand(
unsigned i) {
2232 unsigned numOptional = getIfCond() ? 1 : 0;
2233 numOptional += getAsyncOperand() ? 1 : 0;
2234 numOptional += getWaitDevnum() ? 1 : 0;
2235 return getOperand(getWaitOperands().size() + numOptional + i);
2240 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
2251 if (getDataClauseOperands().empty())
2252 return emitError(
"at least one operand must be present in dataOperands on "
2253 "the enter data operation");
2257 if (getAsyncOperand() && getAsync())
2258 return emitError(
"async attribute cannot appear with asyncOperand");
2262 if (!getWaitOperands().empty() && getWait())
2263 return emitError(
"wait attribute cannot appear with waitOperands");
2265 if (getWaitDevnum() && getWaitOperands().empty())
2266 return emitError(
"wait_devnum cannot appear without waitOperands");
2268 for (
mlir::Value operand : getDataClauseOperands())
2269 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2270 operand.getDefiningOp()))
2271 return emitError(
"expect data entry operation as defining op");
2276 unsigned EnterDataOp::getNumDataOperands() {
2277 return getDataClauseOperands().size();
2280 Value EnterDataOp::getDataOperand(
unsigned i) {
2281 unsigned numOptional = getIfCond() ? 1 : 0;
2282 numOptional += getAsyncOperand() ? 1 : 0;
2283 numOptional += getWaitDevnum() ? 1 : 0;
2284 return getOperand(getWaitOperands().size() + numOptional + i);
2289 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
2308 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2315 if (
Value writeVal = op.getWriteOpVal()) {
2325 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2331 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2332 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2334 return dyn_cast<AtomicReadOp>(getSecondOp());
2337 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2338 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2340 return dyn_cast<AtomicWriteOp>(getSecondOp());
2343 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2344 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2346 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2349 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
2355 template <
typename Op>
2356 static LogicalResult
2358 bool requireAtLeastOneOperand =
true) {
2359 if (operands.empty() && requireAtLeastOneOperand)
2362 "at least one operand must appear on the declare operation");
2365 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2366 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2367 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2368 operand.getDefiningOp()))
2370 "expect valid declare data entry operation or acc.getdeviceptr "
2374 assert(varPtr &&
"declare operands can only be data entry operations which "
2375 "must have varPtr");
2376 std::optional<mlir::acc::DataClause> dataClauseOptional{
2378 assert(dataClauseOptional.has_value() &&
2379 "declare operands can only be data entry operations which must have "
2383 if (!varPtr.getDefiningOp())
2387 auto declareAttribute{
2389 if (!declareAttribute)
2391 "expect declare attribute on variable in declare operation");
2393 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2394 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2396 "expect matching declare attribute on variable in declare operation");
2403 if (declAttr.getImplicit() &&
2406 "implicitness must match between declare op and flag on variable");
2440 acc::DeviceType dtype) {
2441 unsigned parallelism = 0;
2442 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2443 parallelism += op.hasWorker(dtype) ? 1 : 0;
2444 parallelism += op.hasVector(dtype) ? 1 : 0;
2445 parallelism += op.hasSeq(dtype) ? 1 : 0;
2450 unsigned baseParallelism =
2453 if (baseParallelism > 1)
2454 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2455 "be present at the same time";
2457 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2459 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
2464 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2465 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2466 "be present at the same time";
2473 mlir::ArrayAttr &deviceTypes) {
2478 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2480 if (failed(parser.parseOptionalLSquare())) {
2481 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2482 parser.getContext(), mlir::acc::DeviceType::None));
2484 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2485 parser.parseRSquare())
2493 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2499 std::optional<mlir::ArrayAttr> bindName,
2500 std::optional<mlir::ArrayAttr> deviceTypes) {
2501 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2502 [&](
const auto &pair) {
2503 p << std::get<0>(pair);
2509 mlir::ArrayAttr &gang,
2510 mlir::ArrayAttr &gangDim,
2511 mlir::ArrayAttr &gangDimDeviceTypes) {
2514 gangDimDeviceTypeAttrs;
2515 bool needCommaBeforeOperands =
false;
2528 if (parser.parseAttribute(gangAttrs.emplace_back()))
2535 needCommaBeforeOperands =
true;
2538 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2542 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2543 parser.parseColon() ||
2544 parser.parseAttribute(gangDimAttrs.emplace_back()))
2546 if (succeeded(parser.parseOptionalLSquare())) {
2547 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2548 parser.parseRSquare())
2551 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2552 parser.getContext(), mlir::acc::DeviceType::None));
2558 if (failed(parser.parseRParen()))
2563 gangDimDeviceTypes =
2570 std::optional<mlir::ArrayAttr> gang,
2571 std::optional<mlir::ArrayAttr> gangDim,
2572 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2575 gang->size() == 1) {
2576 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2589 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2590 [&](
const auto &pair) {
2591 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
2592 p << std::get<0>(pair);
2600 mlir::ArrayAttr &deviceTypes) {
2613 if (parser.parseAttribute(attributes.emplace_back()))
2627 std::optional<mlir::ArrayAttr> deviceTypes) {
2630 auto deviceTypeAttr =
2631 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2641 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2649 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2655 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2661 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2665 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2669 std::optional<llvm::StringRef>
2670 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2672 return std::nullopt;
2673 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
2674 auto attr = (*getBindName())[*pos];
2675 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2676 return stringAttr.getValue();
2678 return std::nullopt;
2683 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2687 std::optional<int64_t> RoutineOp::getGangDimValue() {
2691 std::optional<int64_t>
2692 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2694 return std::nullopt;
2695 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
2696 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2697 return intAttr.getInt();
2699 return std::nullopt;
2710 return emitOpError(
"cannot be nested in a compute operation");
2722 return emitOpError(
"cannot be nested in a compute operation");
2734 return emitOpError(
"cannot be nested in a compute operation");
2735 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2736 return emitOpError(
"at least one default_async, device_num, or device_type "
2737 "operand must appear");
2747 if (getDataClauseOperands().empty())
2748 return emitError(
"at least one value must be present in dataOperands");
2751 getAsyncOperandsDeviceTypeAttr(),
2756 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2757 getWaitOperandsDeviceTypeAttr(),
"wait")))
2760 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
2763 for (
mlir::Value operand : getDataClauseOperands())
2764 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2765 operand.getDefiningOp()))
2766 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2772 unsigned UpdateOp::getNumDataOperands() {
2773 return getDataClauseOperands().size();
2776 Value UpdateOp::getDataOperand(
unsigned i) {
2778 numOptional += getIfCond() ? 1 : 0;
2779 return getOperand(getWaitOperands().size() + numOptional + i);
2784 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
2787 bool UpdateOp::hasAsyncOnly() {
2791 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2799 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2809 bool UpdateOp::hasWaitOnly() {
2813 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2822 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2824 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2825 getHasWaitDevnum(), deviceType);
2832 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2834 getWaitOperandsSegments(), getHasWaitDevnum(),
2845 if (getAsyncOperand() && getAsync())
2846 return emitError(
"async attribute cannot appear with asyncOperand");
2848 if (getWaitDevnum() && getWaitOperands().empty())
2849 return emitError(
"wait_devnum cannot appear without waitOperands");
2854 #define GET_OP_CLASSES
2855 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2857 #define GET_ATTRDEF_CLASSES
2858 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2860 #define GET_TYPEDEF_CLASSES
2861 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2870 [&](
auto entry) {
return entry.getVarPtr(); })
2871 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2872 [&](
auto exit) {
return exit.getVarPtr(); })
2880 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
2889 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
2899 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
2901 dataClause.getBounds().begin(), dataClause.getBounds().end());
2913 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
2915 dataClause.getAsyncOperands().begin(),
2916 dataClause.getAsyncOperands().end());
2927 return dataClause.getAsyncOperandsDeviceTypeAttr();
2935 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
2942 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
2949 std::optional<mlir::acc::DataClause>
2954 .Case<ACC_DATA_ENTRY_OPS>(
2955 [&](
auto entry) {
return entry.getDataClause(); })
2963 [&](
auto entry) {
return entry.getImplicit(); })
2972 [&](
auto entry) {
return entry.getDataClauseOperands(); })
2974 return dataOperands;
2982 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
2984 return dataOperands;
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.