19 #include "llvm/ADT/SmallSet.h"
20 #include "llvm/ADT/TypeSwitch.h"
25 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
26 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
27 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
28 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
31 struct MemRefPointerLikeModel
32 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
35 return llvm::cast<MemRefType>(pointer).getElementType();
39 struct LLVMPointerPointerLikeModel
40 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
41 LLVM::LLVMPointerType> {
50 void OpenACCDialect::initialize() {
53 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
56 #define GET_ATTRDEF_LIST
57 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
60 #define GET_TYPEDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
67 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
68 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
77 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
83 mlir::acc::DeviceType deviceType) {
87 for (
auto attr : *arrayAttr) {
88 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
89 if (deviceTypeAttr.getValue() == deviceType)
97 std::optional<mlir::ArrayAttr> deviceTypes) {
102 llvm::interleaveComma(*deviceTypes, p,
108 mlir::acc::DeviceType deviceType) {
109 unsigned segmentIdx = 0;
110 for (
auto attr : segments) {
111 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
112 if (deviceTypeAttr.getValue() == deviceType)
113 return std::make_optional(segmentIdx);
123 mlir::acc::DeviceType deviceType) {
125 return range.take_front(0);
126 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
127 int32_t nbOperandsBefore = 0;
128 for (
unsigned i = 0; i < *pos; ++i)
129 nbOperandsBefore += (*segments)[i];
130 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
132 return range.take_front(0);
139 std::optional<mlir::ArrayAttr> hasWaitDevnum,
140 mlir::acc::DeviceType deviceType) {
143 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
144 if (hasWaitDevnum->getValue()[*pos])
155 std::optional<mlir::ArrayAttr> hasWaitDevnum,
156 mlir::acc::DeviceType deviceType) {
161 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
162 if (hasWaitDevnum && *hasWaitDevnum) {
163 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
164 if (boolAttr.getValue())
165 return range.drop_front(1);
171 template <
typename Op>
173 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
175 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
180 op.hasAsyncOnly(dtype))
181 return op.
emitError(
"async attribute cannot appear with asyncOperand");
186 op.hasWaitOnly(dtype))
187 return op.
emitError(
"wait attribute cannot appear with waitOperands");
196 auto extent = getExtent();
197 auto upperbound = getUpperbound();
198 if (!extent && !upperbound)
199 return emitError(
"expected extent or upperbound.");
209 "data clause associated with private operation must match its intent");
218 return emitError(
"data clause associated with firstprivate operation must "
228 return emitError(
"data clause associated with reduction operation must "
238 return emitError(
"data clause associated with deviceptr operation must "
249 "data clause associated with present operation must match its intent");
258 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
263 "data clause associated with copyin operation must match its intent"
264 " or specify original clause this operation was decomposed from");
268 bool acc::CopyinOp::isCopyinReadonly() {
269 return getDataClause() == acc::DataClause::acc_copyin_readonly;
282 "data clause associated with create operation must match its intent"
283 " or specify original clause this operation was decomposed from");
287 bool acc::CreateOp::isCreateZero() {
289 return getDataClause() == acc::DataClause::acc_create_zero ||
298 return emitError(
"data clause associated with no_create operation must "
309 "data clause associated with attach operation must match its intent");
318 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
319 return emitError(
"data clause associated with device_resident operation "
320 "must match its intent");
331 "data clause associated with link operation must match its intent");
345 "data clause associated with copyout operation must match its intent"
346 " or specify original clause this operation was decomposed from");
348 return emitError(
"must have both host and device pointers");
352 bool acc::CopyoutOp::isCopyoutZero() {
367 getDataClause() != acc::DataClause::acc_declare_device_resident &&
370 "data clause associated with delete operation must match its intent"
371 " or specify original clause this operation was decomposed from");
373 return emitError(
"must have device pointer");
385 "data clause associated with detach operation must match its intent"
386 " or specify original clause this operation was decomposed from");
388 return emitError(
"must have device pointer");
400 "data clause associated with host operation must match its intent"
401 " or specify original clause this operation was decomposed from");
403 return emitError(
"must have both host and device pointers");
414 "data clause associated with device operation must match its intent"
415 " or specify original clause this operation was decomposed from");
426 "data clause associated with use_device operation must match its intent"
427 " or specify original clause this operation was decomposed from");
439 "data clause associated with cache operation must match its intent"
440 " or specify original clause this operation was decomposed from");
444 template <
typename StructureOp>
446 unsigned nRegions = 1) {
449 for (
unsigned i = 0; i < nRegions; ++i)
450 regions.push_back(state.addRegion());
452 for (
Region *region : regions)
460 return isa<acc::ParallelOp, acc::LoopOp>(op);
467 template <
typename OpTy>
471 LogicalResult matchAndRewrite(OpTy op,
474 Value ifCond = op.getIfCond();
478 IntegerAttr constAttr;
481 if (constAttr.getInt())
482 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
494 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
506 template <
typename OpTy>
507 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
510 LogicalResult matchAndRewrite(OpTy op,
513 Value ifCond = op.getIfCond();
517 IntegerAttr constAttr;
520 if (constAttr.getInt())
521 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
536 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
538 if (optional && region.
empty())
542 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
546 return op->
emitOpError() <<
"expects " << regionName
549 << regionType <<
" type";
552 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
553 if (yieldOp.getOperands().size() != 1 ||
554 yieldOp.getOperands().getTypes()[0] != type)
555 return op->
emitOpError() <<
"expects " << regionName
557 "yield a value of the "
558 << regionType <<
" type";
564 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
566 "privatization",
"init",
getType(),
570 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
580 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
582 "privatization",
"init",
getType(),
586 if (getCopyRegion().empty())
587 return emitOpError() <<
"expects non-empty copy region";
592 return emitOpError() <<
"expects copy region with two arguments of the "
593 "privatization type";
595 if (getDestroyRegion().empty())
599 "privatization",
"destroy",
610 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
616 if (getCombinerRegion().empty())
617 return emitOpError() <<
"expects non-empty combiner region";
619 Block &reductionBlock = getCombinerRegion().
front();
623 return emitOpError() <<
"expects combiner region with the first two "
624 <<
"arguments of the reduction type";
626 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
627 if (yieldOp.getOperands().size() != 1 ||
628 yieldOp.getOperands().getTypes()[0] !=
getType())
629 return emitOpError() <<
"expects combiner region to yield a value "
630 "of the reduction type";
646 if (parser.parseAttribute(attributes.emplace_back()) ||
647 parser.parseArrow() ||
648 parser.parseOperand(operands.emplace_back()) ||
649 parser.parseColonType(types.emplace_back()))
663 std::optional<mlir::ArrayAttr> attributes) {
664 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
665 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
666 << std::get<1>(it).getType();
675 template <
typename Op>
679 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
680 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
681 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
682 operand.getDefiningOp()))
684 "expect data entry/exit operation or acc.getdeviceptr "
689 template <
typename Op>
693 llvm::StringRef symbolName,
bool checkOperandType =
true) {
694 if (!operands.empty()) {
695 if (!attributes || attributes->size() != operands.size())
697 <<
"expected as many " << symbolName <<
" symbol reference as "
698 << operandName <<
" operands";
702 <<
"unexpected " << symbolName <<
" symbol reference";
707 for (
auto args : llvm::zip(operands, *attributes)) {
710 if (!set.insert(operand).second)
712 << operandName <<
" operand appears more than once";
715 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
716 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
719 <<
"expected symbol reference " << symbolRef <<
" to point to a "
720 << operandName <<
" declaration";
722 if (checkOperandType && decl.getType() && decl.getType() != varType)
723 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
724 <<
") to be the same type as " << operandName
725 <<
" declaration (" << decl.getType() <<
")";
731 unsigned ParallelOp::getNumDataOperands() {
732 return getReductionOperands().size() + getGangPrivateOperands().size() +
733 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
736 Value ParallelOp::getDataOperand(
unsigned i) {
738 numOptional += getNumGangs().size();
739 numOptional += getNumWorkers().size();
740 numOptional += getVectorLength().size();
741 numOptional += getIfCond() ? 1 : 0;
742 numOptional += getSelfCond() ? 1 : 0;
743 return getOperand(getWaitOperands().size() + numOptional + i);
746 template <
typename Op>
748 ArrayAttr deviceTypes,
749 llvm::StringRef keyword) {
750 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
751 return op.
emitOpError() << keyword <<
" operands count must match "
752 << keyword <<
" device_type count";
756 template <
typename Op>
759 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
760 std::size_t numOperandsInSegments = 0;
766 if (maxInSegment != 0 && segCount > maxInSegment)
767 return op.
emitOpError() << keyword <<
" expects a maximum of "
768 << maxInSegment <<
" values per segment";
769 numOperandsInSegments += segCount;
771 if (numOperandsInSegments != operands.size())
773 << keyword <<
" operand count does not match count in segments";
774 if (deviceTypes.getValue().size() != (
size_t)segments.size())
776 << keyword <<
" segment count does not match device_type count";
781 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
782 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
783 "privatizations",
false)))
785 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
786 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
787 "reductions",
false)))
791 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
792 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
796 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
797 getWaitOperandsDeviceTypeAttr(),
"wait")))
801 getNumWorkersDeviceTypeAttr(),
806 getVectorLengthDeviceTypeAttr(),
811 getAsyncOperandsDeviceTypeAttr(),
815 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
818 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
824 mlir::acc::DeviceType deviceType) {
827 if (
auto pos =
findSegment(*arrayAttr, deviceType))
832 bool acc::ParallelOp::hasAsyncOnly() {
836 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
844 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
849 mlir::Value acc::ParallelOp::getNumWorkersValue() {
854 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
859 mlir::Value acc::ParallelOp::getVectorLengthValue() {
864 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
866 getVectorLength(), deviceType);
874 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
876 getNumGangsSegments(), deviceType);
879 bool acc::ParallelOp::hasWaitOnly() {
883 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
892 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
894 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
895 getHasWaitDevnum(), deviceType);
902 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
904 getWaitOperandsSegments(), getHasWaitDevnum(),
920 odsBuilder, odsState, asyncOperands,
nullptr,
921 nullptr, waitOperands,
nullptr,
923 nullptr, numGangs,
nullptr,
925 nullptr, vectorLength,
926 nullptr, ifCond, selfCond,
927 nullptr, reductionOperands,
nullptr,
928 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
929 nullptr, dataClauseOperands,
945 int32_t crtOperandsSize = operands.size();
948 if (parser.parseOperand(operands.emplace_back()) ||
949 parser.parseColonType(types.emplace_back()))
954 seg.push_back(operands.size() - crtOperandsSize);
978 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
980 p <<
" [" << attr <<
"]";
985 std::optional<mlir::ArrayAttr> deviceTypes,
986 std::optional<mlir::DenseI32ArrayAttr> segments) {
990 llvm::interleaveComma(
991 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
992 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1012 int32_t crtOperandsSize = operands.size();
1016 if (parser.parseOperand(operands.emplace_back()) ||
1017 parser.parseColonType(types.emplace_back()))
1023 seg.push_back(operands.size() - crtOperandsSize);
1049 std::optional<mlir::DenseI32ArrayAttr> segments) {
1051 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1053 llvm::interleaveComma(
1054 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1055 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1068 mlir::ArrayAttr &keywordOnly) {
1072 bool needCommaBeforeOperands =
false;
1085 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1092 needCommaBeforeOperands =
true;
1095 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1102 int32_t crtOperandsSize = operands.size();
1114 if (parser.parseOperand(operands.emplace_back()) ||
1115 parser.parseColonType(types.emplace_back()))
1121 seg.push_back(operands.size() - crtOperandsSize);
1150 if (attrs->size() != 1)
1152 if (
auto deviceTypeAttr =
1153 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1160 std::optional<mlir::ArrayAttr> deviceTypes,
1161 std::optional<mlir::DenseI32ArrayAttr> segments,
1162 std::optional<mlir::ArrayAttr> hasDevNum,
1163 std::optional<mlir::ArrayAttr> keywordOnly) {
1175 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1177 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1178 if (boolAttr && boolAttr.getValue())
1180 llvm::interleaveComma(
1181 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1182 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1198 if (parser.parseOperand(operands.emplace_back()) ||
1199 parser.parseColonType(types.emplace_back()))
1201 if (succeeded(parser.parseOptionalLSquare())) {
1202 if (parser.parseAttribute(attributes.emplace_back()) ||
1203 parser.parseRSquare())
1206 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1207 parser.getContext(), mlir::acc::DeviceType::None));
1221 std::optional<mlir::ArrayAttr> deviceTypes) {
1224 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1225 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1234 mlir::ArrayAttr &keywordOnlyDeviceType) {
1237 bool needCommaBeforeOperands =
false;
1243 keywordOnlyDeviceType =
1252 if (parser.parseAttribute(
1253 keywordOnlyDeviceTypeAttributes.emplace_back()))
1260 needCommaBeforeOperands =
true;
1263 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1268 if (parser.parseOperand(operands.emplace_back()) ||
1269 parser.parseColonType(types.emplace_back()))
1271 if (succeeded(parser.parseOptionalLSquare())) {
1272 if (parser.parseAttribute(attributes.emplace_back()) ||
1273 parser.parseRSquare())
1276 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1277 parser.getContext(), mlir::acc::DeviceType::None));
1283 if (failed(parser.parseRParen()))
1295 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1297 if (operands.begin() == operands.end() &&
1313 mlir::acc::CombinedConstructsTypeAttr &attr) {
1319 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1322 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1325 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1328 "expected compute construct name");
1339 mlir::acc::CombinedConstructsTypeAttr attr) {
1341 switch (attr.getValue()) {
1342 case mlir::acc::CombinedConstructsType::KernelsLoop:
1343 p <<
"combined(kernels)";
1345 case mlir::acc::CombinedConstructsType::ParallelLoop:
1346 p <<
"combined(parallel)";
1348 case mlir::acc::CombinedConstructsType::SerialLoop:
1349 p <<
"combined(serial)";
1359 unsigned SerialOp::getNumDataOperands() {
1360 return getReductionOperands().size() + getGangPrivateOperands().size() +
1361 getGangFirstPrivateOperands().size() + getDataClauseOperands().size();
1364 Value SerialOp::getDataOperand(
unsigned i) {
1366 numOptional += getIfCond() ? 1 : 0;
1367 numOptional += getSelfCond() ? 1 : 0;
1368 return getOperand(getWaitOperands().size() + numOptional + i);
1371 bool acc::SerialOp::hasAsyncOnly() {
1375 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1383 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1388 bool acc::SerialOp::hasWaitOnly() {
1392 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1401 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1403 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1404 getHasWaitDevnum(), deviceType);
1411 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1413 getWaitOperandsSegments(), getHasWaitDevnum(),
1418 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1419 *
this, getPrivatizations(), getGangPrivateOperands(),
"private",
1420 "privatizations",
false)))
1422 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1423 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1424 "reductions",
false)))
1428 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1429 getWaitOperandsDeviceTypeAttr(),
"wait")))
1433 getAsyncOperandsDeviceTypeAttr(),
1437 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1440 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1447 unsigned KernelsOp::getNumDataOperands() {
1448 return getDataClauseOperands().size();
1451 Value KernelsOp::getDataOperand(
unsigned i) {
1453 numOptional += getWaitOperands().size();
1454 numOptional += getNumGangs().size();
1455 numOptional += getNumWorkers().size();
1456 numOptional += getVectorLength().size();
1457 numOptional += getIfCond() ? 1 : 0;
1458 numOptional += getSelfCond() ? 1 : 0;
1459 return getOperand(numOptional + i);
1462 bool acc::KernelsOp::hasAsyncOnly() {
1466 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1474 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1479 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1484 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1489 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1494 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1496 getVectorLength(), deviceType);
1504 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1506 getNumGangsSegments(), deviceType);
1509 bool acc::KernelsOp::hasWaitOnly() {
1513 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1522 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1524 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1525 getHasWaitDevnum(), deviceType);
1532 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1534 getWaitOperandsSegments(), getHasWaitDevnum(),
1540 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1541 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1545 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1546 getWaitOperandsDeviceTypeAttr(),
"wait")))
1550 getNumWorkersDeviceTypeAttr(),
1555 getVectorLengthDeviceTypeAttr(),
1560 getAsyncOperandsDeviceTypeAttr(),
1564 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
1567 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
1575 if (getDataClauseOperands().empty())
1576 return emitError(
"at least one operand must appear on the host_data "
1579 for (
mlir::Value operand : getDataClauseOperands())
1580 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1581 return emitError(
"expect data entry operation as defining op");
1587 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1599 bool &needCommaBetweenValues,
bool &newValue) {
1606 attributes.push_back(gangArgType);
1607 needCommaBetweenValues =
true;
1618 mlir::ArrayAttr &gangOnlyDeviceType) {
1623 bool needCommaBetweenValues =
false;
1624 bool needCommaBeforeOperands =
false;
1630 gangOnlyDeviceType =
1639 if (parser.parseAttribute(
1640 gangOnlyDeviceTypeAttributes.emplace_back()))
1647 needCommaBeforeOperands =
true;
1651 mlir::acc::GangArgType::Num);
1653 mlir::acc::GangArgType::Dim);
1655 parser.
getContext(), mlir::acc::GangArgType::Static);
1658 if (needCommaBeforeOperands) {
1659 needCommaBeforeOperands =
false;
1666 int32_t crtOperandsSize = gangOperands.size();
1668 bool newValue =
false;
1669 bool needValue =
false;
1670 if (needCommaBetweenValues) {
1678 gangOperands, gangOperandsType,
1679 gangArgTypeAttributes, argNum,
1680 needCommaBetweenValues, newValue)))
1683 gangOperands, gangOperandsType,
1684 gangArgTypeAttributes, argDim,
1685 needCommaBetweenValues, newValue)))
1687 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1688 gangOperands, gangOperandsType,
1689 gangArgTypeAttributes, argStatic,
1690 needCommaBetweenValues, newValue)))
1693 if (!newValue && needValue) {
1695 "new value expected after comma");
1703 if (gangOperands.empty())
1706 "expect at least one of num, dim or static values");
1712 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
1720 seg.push_back(gangOperands.size() - crtOperandsSize);
1728 gangArgTypeAttributes.end());
1733 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1742 std::optional<mlir::ArrayAttr> gangArgTypes,
1743 std::optional<mlir::ArrayAttr> deviceTypes,
1744 std::optional<mlir::DenseI32ArrayAttr> segments,
1745 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1747 if (operands.begin() == operands.end() &&
1762 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1764 llvm::interleaveComma(
1765 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1766 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1767 (*gangArgTypes)[opIdx]);
1768 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1769 p << LoopOp::getGangNumKeyword();
1770 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1771 p << LoopOp::getGangDimKeyword();
1772 else if (gangArgTypeAttr.getValue() ==
1773 mlir::acc::GangArgType::Static)
1774 p << LoopOp::getGangStaticKeyword();
1775 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
1786 std::optional<mlir::ArrayAttr> segments,
1787 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
1790 for (
auto attr : *segments) {
1791 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1792 if (deviceTypes.contains(deviceTypeAttr.getValue()))
1794 deviceTypes.insert(deviceTypeAttr.getValue());
1801 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
1804 for (
auto attr : deviceTypes) {
1805 auto deviceTypeAttr =
1806 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
1807 if (!deviceTypeAttr)
1809 if (crtDeviceTypes.contains(deviceTypeAttr.getValue()))
1811 crtDeviceTypes.insert(deviceTypeAttr.getValue());
1817 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
1818 (getUpperbound().size() != getInclusiveUpperbound()->size()))
1819 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
1820 <<
" as upperbound size";
1823 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
1824 return emitOpError() <<
"collapse device_type attr must be define when"
1825 <<
" collapse attr is present";
1827 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
1828 getCollapseAttr().getValue().size() !=
1829 getCollapseDeviceTypeAttr().getValue().size())
1830 return emitOpError() <<
"collapse attribute count must match collapse"
1831 <<
" device_type count";
1833 return emitOpError()
1834 <<
"duplicate device_type found in collapseDeviceType attribute";
1837 if (!getGangOperands().empty()) {
1838 if (!getGangOperandsArgType())
1839 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
1840 <<
" when gang operands are present";
1842 if (getGangOperands().size() !=
1843 getGangOperandsArgTypeAttr().getValue().size())
1844 return emitOpError() <<
"gangOperandsArgType attribute count must match"
1845 <<
" gangOperands count";
1848 return emitOpError() <<
"duplicate device_type found in gang attribute";
1851 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
1852 getGangOperandsDeviceTypeAttr(),
"gang")))
1857 return emitOpError() <<
"duplicate device_type found in worker attribute";
1859 return emitOpError() <<
"duplicate device_type found in "
1860 "workerNumOperandsDeviceType attribute";
1862 getWorkerNumOperandsDeviceTypeAttr(),
1868 return emitOpError() <<
"duplicate device_type found in vector attribute";
1870 return emitOpError() <<
"duplicate device_type found in "
1871 "vectorOperandsDeviceType attribute";
1873 getVectorOperandsDeviceTypeAttr(),
1878 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
1879 getTileOperandsDeviceTypeAttr(),
"tile")))
1883 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
1887 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
1888 <<
"\", " << getIndependentAttrName() <<
", "
1890 <<
" can be present at the same time";
1895 for (
auto attr : getSeqAttr()) {
1896 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1897 if (hasVector(deviceTypeAttr.getValue()) ||
1898 getVectorValue(deviceTypeAttr.getValue()) ||
1899 hasWorker(deviceTypeAttr.getValue()) ||
1900 getWorkerValue(deviceTypeAttr.getValue()) ||
1901 hasGang(deviceTypeAttr.getValue()) ||
1902 getGangValue(mlir::acc::GangArgType::Num,
1903 deviceTypeAttr.getValue()) ||
1904 getGangValue(mlir::acc::GangArgType::Dim,
1905 deviceTypeAttr.getValue()) ||
1906 getGangValue(mlir::acc::GangArgType::Static,
1907 deviceTypeAttr.getValue()))
1909 <<
"gang, worker or vector cannot appear with the seq attr";
1913 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1914 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1915 "privatizations",
false)))
1918 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1919 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1920 "reductions",
false)))
1923 if (getCombined().has_value() &&
1924 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
1925 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
1926 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
1927 return emitError(
"unexpected combined constructs attribute");
1931 if (getRegion().empty())
1932 return emitError(
"expected non-empty body.");
1937 unsigned LoopOp::getNumDataOperands() {
1938 return getReductionOperands().size() + getPrivateOperands().size();
1941 Value LoopOp::getDataOperand(
unsigned i) {
1942 unsigned numOptional =
1943 getLowerbound().size() + getUpperbound().size() + getStep().size();
1944 numOptional += getGangOperands().size();
1945 numOptional += getVectorOperands().size();
1946 numOptional += getWorkerNumOperands().size();
1947 numOptional += getTileOperands().size();
1948 numOptional += getCacheOperands().size();
1949 return getOperand(numOptional + i);
1954 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
1958 bool LoopOp::hasIndependent() {
1962 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
1968 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
1976 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
1978 getVectorOperands(), deviceType);
1983 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
1991 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
1993 getWorkerNumOperands(), deviceType);
1998 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2007 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2009 getTileOperandsSegments(), deviceType);
2012 std::optional<int64_t> LoopOp::getCollapseValue() {
2016 std::optional<int64_t>
2017 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2018 if (!getCollapseAttr())
2019 return std::nullopt;
2020 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2022 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2023 return intAttr.getValue().getZExtValue();
2025 return std::nullopt;
2028 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2032 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2033 mlir::acc::DeviceType deviceType) {
2034 if (getGangOperands().empty())
2036 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2037 int32_t nbOperandsBefore = 0;
2038 for (
unsigned i = 0; i < *pos; ++i)
2039 nbOperandsBefore += (*getGangOperandsSegments())[i];
2042 .drop_front(nbOperandsBefore)
2043 .take_front((*getGangOperandsSegments())[*pos]);
2045 int32_t argTypeIdx = nbOperandsBefore;
2046 for (
auto value : values) {
2047 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2048 (*getGangOperandsArgType())[argTypeIdx]);
2049 if (gangArgTypeAttr.getValue() == gangArgType)
2059 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2064 return {&getRegion()};
2108 if (!regionArgs.empty()) {
2109 p << acc::LoopOp::getControlKeyword() <<
"(";
2110 llvm::interleaveComma(regionArgs, p,
2112 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2113 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2114 <<
" : " << stepType <<
") ";
2127 if (getOperands().empty() && !getDefaultAttr())
2128 return emitError(
"at least one operand or the default attribute "
2129 "must appear on the data operation");
2131 for (
mlir::Value operand : getDataClauseOperands())
2132 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2133 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2134 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2135 operand.getDefiningOp()))
2136 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2139 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2145 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2147 Value DataOp::getDataOperand(
unsigned i) {
2148 unsigned numOptional = getIfCond() ? 1 : 0;
2150 numOptional += getWaitOperands().size();
2151 return getOperand(numOptional + i);
2154 bool acc::DataOp::hasAsyncOnly() {
2158 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2166 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2173 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2182 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2184 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2185 getHasWaitDevnum(), deviceType);
2192 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2194 getWaitOperandsSegments(), getHasWaitDevnum(),
2206 if (getDataClauseOperands().empty())
2207 return emitError(
"at least one operand must be present in dataOperands on "
2208 "the exit data operation");
2212 if (getAsyncOperand() && getAsync())
2213 return emitError(
"async attribute cannot appear with asyncOperand");
2217 if (!getWaitOperands().empty() && getWait())
2218 return emitError(
"wait attribute cannot appear with waitOperands");
2220 if (getWaitDevnum() && getWaitOperands().empty())
2221 return emitError(
"wait_devnum cannot appear without waitOperands");
2226 unsigned ExitDataOp::getNumDataOperands() {
2227 return getDataClauseOperands().size();
2230 Value ExitDataOp::getDataOperand(
unsigned i) {
2231 unsigned numOptional = getIfCond() ? 1 : 0;
2232 numOptional += getAsyncOperand() ? 1 : 0;
2233 numOptional += getWaitDevnum() ? 1 : 0;
2234 return getOperand(getWaitOperands().size() + numOptional + i);
2239 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
2250 if (getDataClauseOperands().empty())
2251 return emitError(
"at least one operand must be present in dataOperands on "
2252 "the enter data operation");
2256 if (getAsyncOperand() && getAsync())
2257 return emitError(
"async attribute cannot appear with asyncOperand");
2261 if (!getWaitOperands().empty() && getWait())
2262 return emitError(
"wait attribute cannot appear with waitOperands");
2264 if (getWaitDevnum() && getWaitOperands().empty())
2265 return emitError(
"wait_devnum cannot appear without waitOperands");
2267 for (
mlir::Value operand : getDataClauseOperands())
2268 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2269 operand.getDefiningOp()))
2270 return emitError(
"expect data entry operation as defining op");
2275 unsigned EnterDataOp::getNumDataOperands() {
2276 return getDataClauseOperands().size();
2279 Value EnterDataOp::getDataOperand(
unsigned i) {
2280 unsigned numOptional = getIfCond() ? 1 : 0;
2281 numOptional += getAsyncOperand() ? 1 : 0;
2282 numOptional += getWaitDevnum() ? 1 : 0;
2283 return getOperand(getWaitOperands().size() + numOptional + i);
2288 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
2307 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2314 if (
Value writeVal = op.getWriteOpVal()) {
2324 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2330 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2331 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2333 return dyn_cast<AtomicReadOp>(getSecondOp());
2336 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2337 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2339 return dyn_cast<AtomicWriteOp>(getSecondOp());
2342 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2343 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2345 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2348 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
2354 template <
typename Op>
2355 static LogicalResult
2357 bool requireAtLeastOneOperand =
true) {
2358 if (operands.empty() && requireAtLeastOneOperand)
2361 "at least one operand must appear on the declare operation");
2364 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2365 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2366 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2367 operand.getDefiningOp()))
2369 "expect valid declare data entry operation or acc.getdeviceptr "
2373 assert(varPtr &&
"declare operands can only be data entry operations which "
2374 "must have varPtr");
2375 std::optional<mlir::acc::DataClause> dataClauseOptional{
2377 assert(dataClauseOptional.has_value() &&
2378 "declare operands can only be data entry operations which must have "
2382 if (!varPtr.getDefiningOp())
2386 auto declareAttribute{
2388 if (!declareAttribute)
2390 "expect declare attribute on variable in declare operation");
2392 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2393 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2395 "expect matching declare attribute on variable in declare operation");
2402 if (declAttr.getImplicit() &&
2405 "implicitness must match between declare op and flag on variable");
2439 acc::DeviceType dtype) {
2440 unsigned parallelism = 0;
2441 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2442 parallelism += op.hasWorker(dtype) ? 1 : 0;
2443 parallelism += op.hasVector(dtype) ? 1 : 0;
2444 parallelism += op.hasSeq(dtype) ? 1 : 0;
2449 unsigned baseParallelism =
2452 if (baseParallelism > 1)
2453 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2454 "be present at the same time";
2456 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2458 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
2463 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2464 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2465 "be present at the same time";
2472 mlir::ArrayAttr &deviceTypes) {
2477 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2479 if (failed(parser.parseOptionalLSquare())) {
2480 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2481 parser.getContext(), mlir::acc::DeviceType::None));
2483 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2484 parser.parseRSquare())
2492 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2498 std::optional<mlir::ArrayAttr> bindName,
2499 std::optional<mlir::ArrayAttr> deviceTypes) {
2500 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2501 [&](
const auto &pair) {
2502 p << std::get<0>(pair);
2508 mlir::ArrayAttr &gang,
2509 mlir::ArrayAttr &gangDim,
2510 mlir::ArrayAttr &gangDimDeviceTypes) {
2513 gangDimDeviceTypeAttrs;
2514 bool needCommaBeforeOperands =
false;
2527 if (parser.parseAttribute(gangAttrs.emplace_back()))
2534 needCommaBeforeOperands =
true;
2537 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2541 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2542 parser.parseColon() ||
2543 parser.parseAttribute(gangDimAttrs.emplace_back()))
2545 if (succeeded(parser.parseOptionalLSquare())) {
2546 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2547 parser.parseRSquare())
2550 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2551 parser.getContext(), mlir::acc::DeviceType::None));
2557 if (failed(parser.parseRParen()))
2562 gangDimDeviceTypes =
2569 std::optional<mlir::ArrayAttr> gang,
2570 std::optional<mlir::ArrayAttr> gangDim,
2571 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2574 gang->size() == 1) {
2575 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2588 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2589 [&](
const auto &pair) {
2590 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
2591 p << std::get<0>(pair);
2599 mlir::ArrayAttr &deviceTypes) {
2612 if (parser.parseAttribute(attributes.emplace_back()))
2626 std::optional<mlir::ArrayAttr> deviceTypes) {
2629 auto deviceTypeAttr =
2630 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2640 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2648 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2654 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2660 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2664 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2668 std::optional<llvm::StringRef>
2669 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2671 return std::nullopt;
2672 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
2673 auto attr = (*getBindName())[*pos];
2674 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2675 return stringAttr.getValue();
2677 return std::nullopt;
2682 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2686 std::optional<int64_t> RoutineOp::getGangDimValue() {
2690 std::optional<int64_t>
2691 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2693 return std::nullopt;
2694 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
2695 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2696 return intAttr.getInt();
2698 return std::nullopt;
2709 return emitOpError(
"cannot be nested in a compute operation");
2721 return emitOpError(
"cannot be nested in a compute operation");
2733 return emitOpError(
"cannot be nested in a compute operation");
2734 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2735 return emitOpError(
"at least one default_async, device_num, or device_type "
2736 "operand must appear");
2746 if (getDataClauseOperands().empty())
2747 return emitError(
"at least one value must be present in dataOperands");
2750 getAsyncOperandsDeviceTypeAttr(),
2755 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2756 getWaitOperandsDeviceTypeAttr(),
"wait")))
2759 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
2762 for (
mlir::Value operand : getDataClauseOperands())
2763 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2764 operand.getDefiningOp()))
2765 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2771 unsigned UpdateOp::getNumDataOperands() {
2772 return getDataClauseOperands().size();
2775 Value UpdateOp::getDataOperand(
unsigned i) {
2777 numOptional += getIfCond() ? 1 : 0;
2778 return getOperand(getWaitOperands().size() + numOptional + i);
2783 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
2786 bool UpdateOp::hasAsyncOnly() {
2790 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2798 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2808 bool UpdateOp::hasWaitOnly() {
2812 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2821 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2823 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2824 getHasWaitDevnum(), deviceType);
2831 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2833 getWaitOperandsSegments(), getHasWaitDevnum(),
2844 if (getAsyncOperand() && getAsync())
2845 return emitError(
"async attribute cannot appear with asyncOperand");
2847 if (getWaitDevnum() && getWaitOperands().empty())
2848 return emitError(
"wait_devnum cannot appear without waitOperands");
2853 #define GET_OP_CLASSES
2854 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
2856 #define GET_ATTRDEF_CLASSES
2857 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
2859 #define GET_TYPEDEF_CLASSES
2860 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
2869 [&](
auto entry) {
return entry.getVarPtr(); })
2870 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
2871 [&](
auto exit) {
return exit.getVarPtr(); })
2879 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
2888 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
2898 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
2900 dataClause.getBounds().begin(), dataClause.getBounds().end());
2912 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
2914 dataClause.getAsyncOperands().begin(),
2915 dataClause.getAsyncOperands().end());
2926 return dataClause.getAsyncOperandsDeviceTypeAttr();
2934 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
2941 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
2948 std::optional<mlir::acc::DataClause>
2953 .Case<ACC_DATA_ENTRY_OPS>(
2954 [&](
auto entry) {
return entry.getDataClause(); })
2962 [&](
auto entry) {
return entry.getImplicit(); })
2971 [&](
auto entry) {
return entry.getDataClauseOperands(); })
2973 return dataOperands;
2981 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
2983 return dataOperands;
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
This class helps build Operations.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Location getLoc()
The source location the operation was defined or derived from.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtr from a data clause operation.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::Value getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.