21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
35 struct MemRefPointerLikeModel
36 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
39 return llvm::cast<MemRefType>(pointer).getElementType();
43 struct LLVMPointerPointerLikeModel
44 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
45 LLVM::LLVMPointerType> {
54 void OpenACCDialect::initialize() {
57 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
60 #define GET_ATTRDEF_LIST
61 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
64 #define GET_TYPEDEF_LIST
65 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
71 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
72 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
81 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
87 mlir::acc::DeviceType deviceType) {
91 for (
auto attr : *arrayAttr) {
92 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
93 if (deviceTypeAttr.getValue() == deviceType)
101 std::optional<mlir::ArrayAttr> deviceTypes) {
106 llvm::interleaveComma(*deviceTypes, p,
112 mlir::acc::DeviceType deviceType) {
113 unsigned segmentIdx = 0;
114 for (
auto attr : segments) {
115 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
116 if (deviceTypeAttr.getValue() == deviceType)
117 return std::make_optional(segmentIdx);
127 mlir::acc::DeviceType deviceType) {
129 return range.take_front(0);
130 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
131 int32_t nbOperandsBefore = 0;
132 for (
unsigned i = 0; i < *pos; ++i)
133 nbOperandsBefore += (*segments)[i];
134 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
136 return range.take_front(0);
143 std::optional<mlir::ArrayAttr> hasWaitDevnum,
144 mlir::acc::DeviceType deviceType) {
147 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
148 if (hasWaitDevnum->getValue()[*pos])
159 std::optional<mlir::ArrayAttr> hasWaitDevnum,
160 mlir::acc::DeviceType deviceType) {
165 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
166 if (hasWaitDevnum && *hasWaitDevnum) {
167 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
168 if (boolAttr.getValue())
169 return range.drop_front(1);
175 template <
typename Op>
177 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
179 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
184 op.hasAsyncOnly(dtype))
185 return op.
emitError(
"async attribute cannot appear with asyncOperand");
190 op.hasWaitOnly(dtype))
191 return op.
emitError(
"wait attribute cannot appear with waitOperands");
196 template <
typename Op>
199 return op.
emitError(
"must have var operand");
201 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
202 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
207 return op.
emitError(
"var must be mappable or pointer-like (not both)");
210 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
211 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
212 return op.
emitError(
"var must be mappable or pointer-like");
214 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
215 op.getVarType() != op.getVar().getType())
216 return op.
emitError(
"varType must match when var is mappable");
221 template <
typename Op>
223 if (op.getVar().getType() != op.getAccVar().getType())
224 return op.
emitError(
"input and output types must match");
246 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
267 if (failed(parser.
parseType(accVarType)))
277 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
289 mlir::TypeAttr &varTypeAttr) {
290 if (failed(parser.
parseType(varPtrType)))
306 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
308 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
317 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
325 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
326 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
328 if (typeToCheckAgainst != varType) {
339 auto extent = getExtent();
340 auto upperbound = getUpperbound();
341 if (!extent && !upperbound)
342 return emitError(
"expected extent or upperbound.");
352 "data clause associated with private operation must match its intent");
363 return emitError(
"data clause associated with firstprivate operation must "
375 return emitError(
"data clause associated with reduction operation must "
387 return emitError(
"data clause associated with deviceptr operation must "
402 "data clause associated with present operation must match its intent");
415 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
420 "data clause associated with copyin operation must match its intent"
421 " or specify original clause this operation was decomposed from");
429 bool acc::CopyinOp::isCopyinReadonly() {
430 return getDataClause() == acc::DataClause::acc_copyin_readonly;
443 "data clause associated with create operation must match its intent"
444 " or specify original clause this operation was decomposed from");
452 bool acc::CreateOp::isCreateZero() {
454 return getDataClause() == acc::DataClause::acc_create_zero ||
463 return emitError(
"data clause associated with no_create operation must "
478 "data clause associated with attach operation must match its intent");
491 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
492 return emitError(
"data clause associated with device_resident operation "
493 "must match its intent");
508 "data clause associated with link operation must match its intent");
526 "data clause associated with copyout operation must match its intent"
527 " or specify original clause this operation was decomposed from");
529 return emitError(
"must have both host and device pointers");
537 bool acc::CopyoutOp::isCopyoutZero() {
552 getDataClause() != acc::DataClause::acc_declare_device_resident &&
555 "data clause associated with delete operation must match its intent"
556 " or specify original clause this operation was decomposed from");
558 return emitError(
"must have device pointer");
570 "data clause associated with detach operation must match its intent"
571 " or specify original clause this operation was decomposed from");
573 return emitError(
"must have device pointer");
585 "data clause associated with host operation must match its intent"
586 " or specify original clause this operation was decomposed from");
588 return emitError(
"must have both host and device pointers");
603 "data clause associated with device operation must match its intent"
604 " or specify original clause this operation was decomposed from");
619 "data clause associated with use_device operation must match its intent"
620 " or specify original clause this operation was decomposed from");
636 "data clause associated with cache operation must match its intent"
637 " or specify original clause this operation was decomposed from");
645 template <
typename StructureOp>
647 unsigned nRegions = 1) {
650 for (
unsigned i = 0; i < nRegions; ++i)
651 regions.push_back(state.addRegion());
653 for (
Region *region : regions)
661 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
668 template <
typename OpTy>
672 LogicalResult matchAndRewrite(OpTy op,
675 Value ifCond = op.getIfCond();
679 IntegerAttr constAttr;
682 if (constAttr.getInt())
683 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
695 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
707 template <
typename OpTy>
708 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
711 LogicalResult matchAndRewrite(OpTy op,
714 Value ifCond = op.getIfCond();
718 IntegerAttr constAttr;
721 if (constAttr.getInt())
722 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
737 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
739 if (optional && region.
empty())
743 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
747 return op->
emitOpError() <<
"expects " << regionName
750 << regionType <<
" type";
753 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
754 if (yieldOp.getOperands().size() != 1 ||
755 yieldOp.getOperands().getTypes()[0] != type)
756 return op->
emitOpError() <<
"expects " << regionName
758 "yield a value of the "
759 << regionType <<
" type";
765 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
767 "privatization",
"init",
getType(),
771 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
781 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
783 "privatization",
"init",
getType(),
787 if (getCopyRegion().empty())
788 return emitOpError() <<
"expects non-empty copy region";
793 return emitOpError() <<
"expects copy region with two arguments of the "
794 "privatization type";
796 if (getDestroyRegion().empty())
800 "privatization",
"destroy",
811 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
817 if (getCombinerRegion().empty())
818 return emitOpError() <<
"expects non-empty combiner region";
820 Block &reductionBlock = getCombinerRegion().
front();
824 return emitOpError() <<
"expects combiner region with the first two "
825 <<
"arguments of the reduction type";
827 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
828 if (yieldOp.getOperands().size() != 1 ||
829 yieldOp.getOperands().getTypes()[0] !=
getType())
830 return emitOpError() <<
"expects combiner region to yield a value "
831 "of the reduction type";
847 if (parser.parseAttribute(attributes.emplace_back()) ||
848 parser.parseArrow() ||
849 parser.parseOperand(operands.emplace_back()) ||
850 parser.parseColonType(types.emplace_back()))
864 std::optional<mlir::ArrayAttr> attributes) {
865 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
866 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
867 << std::get<1>(it).getType();
876 template <
typename Op>
880 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
881 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
882 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
883 operand.getDefiningOp()))
885 "expect data entry/exit operation or acc.getdeviceptr "
890 template <
typename Op>
894 llvm::StringRef symbolName,
bool checkOperandType =
true) {
895 if (!operands.empty()) {
896 if (!attributes || attributes->size() != operands.size())
898 <<
"expected as many " << symbolName <<
" symbol reference as "
899 << operandName <<
" operands";
903 <<
"unexpected " << symbolName <<
" symbol reference";
908 for (
auto args : llvm::zip(operands, *attributes)) {
911 if (!set.insert(operand).second)
913 << operandName <<
" operand appears more than once";
916 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
917 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
920 <<
"expected symbol reference " << symbolRef <<
" to point to a "
921 << operandName <<
" declaration";
923 if (checkOperandType && decl.getType() && decl.getType() != varType)
924 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
925 <<
") to be the same type as " << operandName
926 <<
" declaration (" << decl.getType() <<
")";
932 unsigned ParallelOp::getNumDataOperands() {
933 return getReductionOperands().size() + getPrivateOperands().size() +
934 getFirstprivateOperands().size() + getDataClauseOperands().size();
937 Value ParallelOp::getDataOperand(
unsigned i) {
939 numOptional += getNumGangs().size();
940 numOptional += getNumWorkers().size();
941 numOptional += getVectorLength().size();
942 numOptional += getIfCond() ? 1 : 0;
943 numOptional += getSelfCond() ? 1 : 0;
944 return getOperand(getWaitOperands().size() + numOptional + i);
947 template <
typename Op>
949 ArrayAttr deviceTypes,
950 llvm::StringRef keyword) {
951 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
952 return op.
emitOpError() << keyword <<
" operands count must match "
953 << keyword <<
" device_type count";
957 template <
typename Op>
960 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
961 std::size_t numOperandsInSegments = 0;
962 std::size_t nbOfSegments = 0;
966 if (maxInSegment != 0 && segCount > maxInSegment)
967 return op.
emitOpError() << keyword <<
" expects a maximum of "
968 << maxInSegment <<
" values per segment";
969 numOperandsInSegments += segCount;
974 if ((numOperandsInSegments != operands.size()) ||
975 (!deviceTypes && !operands.empty()))
977 << keyword <<
" operand count does not match count in segments";
978 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
980 << keyword <<
" segment count does not match device_type count";
985 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
986 *
this, getPrivatizations(), getPrivateOperands(),
"private",
987 "privatizations",
false)))
989 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
990 *
this, getFirstprivatizations(), getFirstprivateOperands(),
991 "firstprivate",
"firstprivatizations",
false)))
993 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
994 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
995 "reductions",
false)))
999 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1000 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1004 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1005 getWaitOperandsDeviceTypeAttr(),
"wait")))
1009 getNumWorkersDeviceTypeAttr(),
1014 getVectorLengthDeviceTypeAttr(),
1019 getAsyncOperandsDeviceTypeAttr(),
1023 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1026 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1032 mlir::acc::DeviceType deviceType) {
1035 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1040 bool acc::ParallelOp::hasAsyncOnly() {
1044 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1052 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1057 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1062 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1067 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1072 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1074 getVectorLength(), deviceType);
1082 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1084 getNumGangsSegments(), deviceType);
1087 bool acc::ParallelOp::hasWaitOnly() {
1091 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1100 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1102 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1103 getHasWaitDevnum(), deviceType);
1110 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1112 getWaitOperandsSegments(), getHasWaitDevnum(),
1128 odsBuilder, odsState, asyncOperands,
nullptr,
1129 nullptr, waitOperands,
nullptr,
1131 nullptr, numGangs,
nullptr,
1132 nullptr, numWorkers,
1133 nullptr, vectorLength,
1134 nullptr, ifCond, selfCond,
1135 nullptr, reductionOperands,
nullptr,
1136 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1137 nullptr, dataClauseOperands,
1153 int32_t crtOperandsSize = operands.size();
1156 if (parser.parseOperand(operands.emplace_back()) ||
1157 parser.parseColonType(types.emplace_back()))
1162 seg.push_back(operands.size() - crtOperandsSize);
1186 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1188 p <<
" [" << attr <<
"]";
1193 std::optional<mlir::ArrayAttr> deviceTypes,
1194 std::optional<mlir::DenseI32ArrayAttr> segments) {
1196 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1198 llvm::interleaveComma(
1199 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1200 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1220 int32_t crtOperandsSize = operands.size();
1224 if (parser.parseOperand(operands.emplace_back()) ||
1225 parser.parseColonType(types.emplace_back()))
1231 seg.push_back(operands.size() - crtOperandsSize);
1257 std::optional<mlir::DenseI32ArrayAttr> segments) {
1259 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1261 llvm::interleaveComma(
1262 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1263 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1276 mlir::ArrayAttr &keywordOnly) {
1280 bool needCommaBeforeOperands =
false;
1293 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1300 needCommaBeforeOperands =
true;
1303 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1310 int32_t crtOperandsSize = operands.size();
1322 if (parser.parseOperand(operands.emplace_back()) ||
1323 parser.parseColonType(types.emplace_back()))
1329 seg.push_back(operands.size() - crtOperandsSize);
1358 if (attrs->size() != 1)
1360 if (
auto deviceTypeAttr =
1361 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1368 std::optional<mlir::ArrayAttr> deviceTypes,
1369 std::optional<mlir::DenseI32ArrayAttr> segments,
1370 std::optional<mlir::ArrayAttr> hasDevNum,
1371 std::optional<mlir::ArrayAttr> keywordOnly) {
1383 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1385 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1386 if (boolAttr && boolAttr.getValue())
1388 llvm::interleaveComma(
1389 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1390 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1406 if (parser.parseOperand(operands.emplace_back()) ||
1407 parser.parseColonType(types.emplace_back()))
1409 if (succeeded(parser.parseOptionalLSquare())) {
1410 if (parser.parseAttribute(attributes.emplace_back()) ||
1411 parser.parseRSquare())
1414 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1415 parser.getContext(), mlir::acc::DeviceType::None));
1429 std::optional<mlir::ArrayAttr> deviceTypes) {
1432 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1433 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1442 mlir::ArrayAttr &keywordOnlyDeviceType) {
1445 bool needCommaBeforeOperands =
false;
1451 keywordOnlyDeviceType =
1460 if (parser.parseAttribute(
1461 keywordOnlyDeviceTypeAttributes.emplace_back()))
1468 needCommaBeforeOperands =
true;
1471 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1476 if (parser.parseOperand(operands.emplace_back()) ||
1477 parser.parseColonType(types.emplace_back()))
1479 if (succeeded(parser.parseOptionalLSquare())) {
1480 if (parser.parseAttribute(attributes.emplace_back()) ||
1481 parser.parseRSquare())
1484 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1485 parser.getContext(), mlir::acc::DeviceType::None));
1491 if (failed(parser.parseRParen()))
1503 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1505 if (operands.begin() == operands.end() &&
1521 mlir::acc::CombinedConstructsTypeAttr &attr) {
1527 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1530 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1533 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1536 "expected compute construct name");
1547 mlir::acc::CombinedConstructsTypeAttr attr) {
1549 switch (attr.getValue()) {
1550 case mlir::acc::CombinedConstructsType::KernelsLoop:
1551 p <<
"combined(kernels)";
1553 case mlir::acc::CombinedConstructsType::ParallelLoop:
1554 p <<
"combined(parallel)";
1556 case mlir::acc::CombinedConstructsType::SerialLoop:
1557 p <<
"combined(serial)";
1567 unsigned SerialOp::getNumDataOperands() {
1568 return getReductionOperands().size() + getPrivateOperands().size() +
1569 getFirstprivateOperands().size() + getDataClauseOperands().size();
1572 Value SerialOp::getDataOperand(
unsigned i) {
1574 numOptional += getIfCond() ? 1 : 0;
1575 numOptional += getSelfCond() ? 1 : 0;
1576 return getOperand(getWaitOperands().size() + numOptional + i);
1579 bool acc::SerialOp::hasAsyncOnly() {
1583 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1591 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1596 bool acc::SerialOp::hasWaitOnly() {
1600 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1609 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1611 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1612 getHasWaitDevnum(), deviceType);
1619 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1621 getWaitOperandsSegments(), getHasWaitDevnum(),
1626 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1627 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1628 "privatizations",
false)))
1630 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1631 *
this, getFirstprivatizations(), getFirstprivateOperands(),
1632 "firstprivate",
"firstprivatizations",
false)))
1634 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1635 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1636 "reductions",
false)))
1640 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1641 getWaitOperandsDeviceTypeAttr(),
"wait")))
1645 getAsyncOperandsDeviceTypeAttr(),
1649 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1652 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1659 unsigned KernelsOp::getNumDataOperands() {
1660 return getDataClauseOperands().size();
1663 Value KernelsOp::getDataOperand(
unsigned i) {
1665 numOptional += getWaitOperands().size();
1666 numOptional += getNumGangs().size();
1667 numOptional += getNumWorkers().size();
1668 numOptional += getVectorLength().size();
1669 numOptional += getIfCond() ? 1 : 0;
1670 numOptional += getSelfCond() ? 1 : 0;
1671 return getOperand(numOptional + i);
1674 bool acc::KernelsOp::hasAsyncOnly() {
1678 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1686 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1691 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1696 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1701 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1706 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1708 getVectorLength(), deviceType);
1716 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1718 getNumGangsSegments(), deviceType);
1721 bool acc::KernelsOp::hasWaitOnly() {
1725 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1734 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1736 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1737 getHasWaitDevnum(), deviceType);
1744 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1746 getWaitOperandsSegments(), getHasWaitDevnum(),
1752 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1753 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1757 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1758 getWaitOperandsDeviceTypeAttr(),
"wait")))
1762 getNumWorkersDeviceTypeAttr(),
1767 getVectorLengthDeviceTypeAttr(),
1772 getAsyncOperandsDeviceTypeAttr(),
1776 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
1779 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
1787 if (getDataClauseOperands().empty())
1788 return emitError(
"at least one operand must appear on the host_data "
1791 for (
mlir::Value operand : getDataClauseOperands())
1792 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1793 return emitError(
"expect data entry operation as defining op");
1799 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1811 bool &needCommaBetweenValues,
bool &newValue) {
1818 attributes.push_back(gangArgType);
1819 needCommaBetweenValues =
true;
1830 mlir::ArrayAttr &gangOnlyDeviceType) {
1835 bool needCommaBetweenValues =
false;
1836 bool needCommaBeforeOperands =
false;
1842 gangOnlyDeviceType =
1851 if (parser.parseAttribute(
1852 gangOnlyDeviceTypeAttributes.emplace_back()))
1859 needCommaBeforeOperands =
true;
1863 mlir::acc::GangArgType::Num);
1865 mlir::acc::GangArgType::Dim);
1867 parser.
getContext(), mlir::acc::GangArgType::Static);
1870 if (needCommaBeforeOperands) {
1871 needCommaBeforeOperands =
false;
1878 int32_t crtOperandsSize = gangOperands.size();
1880 bool newValue =
false;
1881 bool needValue =
false;
1882 if (needCommaBetweenValues) {
1890 gangOperands, gangOperandsType,
1891 gangArgTypeAttributes, argNum,
1892 needCommaBetweenValues, newValue)))
1895 gangOperands, gangOperandsType,
1896 gangArgTypeAttributes, argDim,
1897 needCommaBetweenValues, newValue)))
1899 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1900 gangOperands, gangOperandsType,
1901 gangArgTypeAttributes, argStatic,
1902 needCommaBetweenValues, newValue)))
1905 if (!newValue && needValue) {
1907 "new value expected after comma");
1915 if (gangOperands.empty())
1918 "expect at least one of num, dim or static values");
1924 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
1932 seg.push_back(gangOperands.size() - crtOperandsSize);
1940 gangArgTypeAttributes.end());
1945 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1954 std::optional<mlir::ArrayAttr> gangArgTypes,
1955 std::optional<mlir::ArrayAttr> deviceTypes,
1956 std::optional<mlir::DenseI32ArrayAttr> segments,
1957 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1959 if (operands.begin() == operands.end() &&
1974 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1976 llvm::interleaveComma(
1977 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1978 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
1979 (*gangArgTypes)[opIdx]);
1980 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
1981 p << LoopOp::getGangNumKeyword();
1982 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
1983 p << LoopOp::getGangDimKeyword();
1984 else if (gangArgTypeAttr.getValue() ==
1985 mlir::acc::GangArgType::Static)
1986 p << LoopOp::getGangStaticKeyword();
1987 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
1998 std::optional<mlir::ArrayAttr> segments,
1999 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2002 for (
auto attr : *segments) {
2003 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2004 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2012 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2015 for (
auto attr : deviceTypes) {
2016 auto deviceTypeAttr =
2017 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2018 if (!deviceTypeAttr)
2020 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2027 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2028 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2029 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2030 <<
" as upperbound size";
2033 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2034 return emitOpError() <<
"collapse device_type attr must be define when"
2035 <<
" collapse attr is present";
2037 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2038 getCollapseAttr().getValue().size() !=
2039 getCollapseDeviceTypeAttr().getValue().size())
2040 return emitOpError() <<
"collapse attribute count must match collapse"
2041 <<
" device_type count";
2043 return emitOpError()
2044 <<
"duplicate device_type found in collapseDeviceType attribute";
2047 if (!getGangOperands().empty()) {
2048 if (!getGangOperandsArgType())
2049 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2050 <<
" when gang operands are present";
2052 if (getGangOperands().size() !=
2053 getGangOperandsArgTypeAttr().getValue().size())
2054 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2055 <<
" gangOperands count";
2058 return emitOpError() <<
"duplicate device_type found in gang attribute";
2061 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2062 getGangOperandsDeviceTypeAttr(),
"gang")))
2067 return emitOpError() <<
"duplicate device_type found in worker attribute";
2069 return emitOpError() <<
"duplicate device_type found in "
2070 "workerNumOperandsDeviceType attribute";
2072 getWorkerNumOperandsDeviceTypeAttr(),
2078 return emitOpError() <<
"duplicate device_type found in vector attribute";
2080 return emitOpError() <<
"duplicate device_type found in "
2081 "vectorOperandsDeviceType attribute";
2083 getVectorOperandsDeviceTypeAttr(),
2088 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2089 getTileOperandsDeviceTypeAttr(),
"tile")))
2093 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2097 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
2098 <<
"\", " << getIndependentAttrName() <<
", "
2100 <<
" can be present at the same time";
2105 for (
auto attr : getSeqAttr()) {
2106 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2107 if (hasVector(deviceTypeAttr.getValue()) ||
2108 getVectorValue(deviceTypeAttr.getValue()) ||
2109 hasWorker(deviceTypeAttr.getValue()) ||
2110 getWorkerValue(deviceTypeAttr.getValue()) ||
2111 hasGang(deviceTypeAttr.getValue()) ||
2112 getGangValue(mlir::acc::GangArgType::Num,
2113 deviceTypeAttr.getValue()) ||
2114 getGangValue(mlir::acc::GangArgType::Dim,
2115 deviceTypeAttr.getValue()) ||
2116 getGangValue(mlir::acc::GangArgType::Static,
2117 deviceTypeAttr.getValue()))
2119 <<
"gang, worker or vector cannot appear with the seq attr";
2123 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2124 *
this, getPrivatizations(), getPrivateOperands(),
"private",
2125 "privatizations",
false)))
2128 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2129 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2130 "reductions",
false)))
2133 if (getCombined().has_value() &&
2134 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2135 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2136 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2137 return emitError(
"unexpected combined constructs attribute");
2141 if (getRegion().empty())
2142 return emitError(
"expected non-empty body.");
2147 unsigned LoopOp::getNumDataOperands() {
2148 return getReductionOperands().size() + getPrivateOperands().size();
2151 Value LoopOp::getDataOperand(
unsigned i) {
2152 unsigned numOptional =
2153 getLowerbound().size() + getUpperbound().size() + getStep().size();
2154 numOptional += getGangOperands().size();
2155 numOptional += getVectorOperands().size();
2156 numOptional += getWorkerNumOperands().size();
2157 numOptional += getTileOperands().size();
2158 numOptional += getCacheOperands().size();
2159 return getOperand(numOptional + i);
2164 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2168 bool LoopOp::hasIndependent() {
2172 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2178 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2186 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2188 getVectorOperands(), deviceType);
2193 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2201 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2203 getWorkerNumOperands(), deviceType);
2208 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2217 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2219 getTileOperandsSegments(), deviceType);
2222 std::optional<int64_t> LoopOp::getCollapseValue() {
2226 std::optional<int64_t>
2227 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2228 if (!getCollapseAttr())
2229 return std::nullopt;
2230 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2232 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2233 return intAttr.getValue().getZExtValue();
2235 return std::nullopt;
2238 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2242 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2243 mlir::acc::DeviceType deviceType) {
2244 if (getGangOperands().empty())
2246 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2247 int32_t nbOperandsBefore = 0;
2248 for (
unsigned i = 0; i < *pos; ++i)
2249 nbOperandsBefore += (*getGangOperandsSegments())[i];
2252 .drop_front(nbOperandsBefore)
2253 .take_front((*getGangOperandsSegments())[*pos]);
2255 int32_t argTypeIdx = nbOperandsBefore;
2256 for (
auto value : values) {
2257 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2258 (*getGangOperandsArgType())[argTypeIdx]);
2259 if (gangArgTypeAttr.getValue() == gangArgType)
2269 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2274 return {&getRegion()};
2318 if (!regionArgs.empty()) {
2319 p << acc::LoopOp::getControlKeyword() <<
"(";
2320 llvm::interleaveComma(regionArgs, p,
2322 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2323 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2324 <<
" : " << stepType <<
") ";
2337 if (getOperands().empty() && !getDefaultAttr())
2338 return emitError(
"at least one operand or the default attribute "
2339 "must appear on the data operation");
2341 for (
mlir::Value operand : getDataClauseOperands())
2342 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2343 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2344 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2345 operand.getDefiningOp()))
2346 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2349 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2355 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2357 Value DataOp::getDataOperand(
unsigned i) {
2358 unsigned numOptional = getIfCond() ? 1 : 0;
2360 numOptional += getWaitOperands().size();
2361 return getOperand(numOptional + i);
2364 bool acc::DataOp::hasAsyncOnly() {
2368 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2376 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2383 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2392 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2394 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2395 getHasWaitDevnum(), deviceType);
2402 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2404 getWaitOperandsSegments(), getHasWaitDevnum(),
2416 if (getDataClauseOperands().empty())
2417 return emitError(
"at least one operand must be present in dataOperands on "
2418 "the exit data operation");
2422 if (getAsyncOperand() && getAsync())
2423 return emitError(
"async attribute cannot appear with asyncOperand");
2427 if (!getWaitOperands().empty() && getWait())
2428 return emitError(
"wait attribute cannot appear with waitOperands");
2430 if (getWaitDevnum() && getWaitOperands().empty())
2431 return emitError(
"wait_devnum cannot appear without waitOperands");
2436 unsigned ExitDataOp::getNumDataOperands() {
2437 return getDataClauseOperands().size();
2440 Value ExitDataOp::getDataOperand(
unsigned i) {
2441 unsigned numOptional = getIfCond() ? 1 : 0;
2442 numOptional += getAsyncOperand() ? 1 : 0;
2443 numOptional += getWaitDevnum() ? 1 : 0;
2444 return getOperand(getWaitOperands().size() + numOptional + i);
2449 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
2460 if (getDataClauseOperands().empty())
2461 return emitError(
"at least one operand must be present in dataOperands on "
2462 "the enter data operation");
2466 if (getAsyncOperand() && getAsync())
2467 return emitError(
"async attribute cannot appear with asyncOperand");
2471 if (!getWaitOperands().empty() && getWait())
2472 return emitError(
"wait attribute cannot appear with waitOperands");
2474 if (getWaitDevnum() && getWaitOperands().empty())
2475 return emitError(
"wait_devnum cannot appear without waitOperands");
2477 for (
mlir::Value operand : getDataClauseOperands())
2478 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2479 operand.getDefiningOp()))
2480 return emitError(
"expect data entry operation as defining op");
2485 unsigned EnterDataOp::getNumDataOperands() {
2486 return getDataClauseOperands().size();
2489 Value EnterDataOp::getDataOperand(
unsigned i) {
2490 unsigned numOptional = getIfCond() ? 1 : 0;
2491 numOptional += getAsyncOperand() ? 1 : 0;
2492 numOptional += getWaitDevnum() ? 1 : 0;
2493 return getOperand(getWaitOperands().size() + numOptional + i);
2498 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
2517 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2524 if (
Value writeVal = op.getWriteOpVal()) {
2534 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2540 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2541 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2543 return dyn_cast<AtomicReadOp>(getSecondOp());
2546 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2547 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2549 return dyn_cast<AtomicWriteOp>(getSecondOp());
2552 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2553 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2555 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2558 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
2564 template <
typename Op>
2565 static LogicalResult
2567 bool requireAtLeastOneOperand =
true) {
2568 if (operands.empty() && requireAtLeastOneOperand)
2571 "at least one operand must appear on the declare operation");
2574 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2575 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2576 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2577 operand.getDefiningOp()))
2579 "expect valid declare data entry operation or acc.getdeviceptr "
2583 assert(varPtr &&
"declare operands can only be data entry operations which "
2584 "must have varPtr");
2585 std::optional<mlir::acc::DataClause> dataClauseOptional{
2587 assert(dataClauseOptional.has_value() &&
2588 "declare operands can only be data entry operations which must have "
2592 if (!varPtr.getDefiningOp())
2596 auto declareAttribute{
2598 if (!declareAttribute)
2600 "expect declare attribute on variable in declare operation");
2602 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2603 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2605 "expect matching declare attribute on variable in declare operation");
2612 if (declAttr.getImplicit() &&
2615 "implicitness must match between declare op and flag on variable");
2649 acc::DeviceType dtype) {
2650 unsigned parallelism = 0;
2651 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2652 parallelism += op.hasWorker(dtype) ? 1 : 0;
2653 parallelism += op.hasVector(dtype) ? 1 : 0;
2654 parallelism += op.hasSeq(dtype) ? 1 : 0;
2659 unsigned baseParallelism =
2662 if (baseParallelism > 1)
2663 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2664 "be present at the same time";
2666 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2668 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
2673 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2674 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2675 "be present at the same time";
2682 mlir::ArrayAttr &deviceTypes) {
2687 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2689 if (failed(parser.parseOptionalLSquare())) {
2690 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2691 parser.getContext(), mlir::acc::DeviceType::None));
2693 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2694 parser.parseRSquare())
2702 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2708 std::optional<mlir::ArrayAttr> bindName,
2709 std::optional<mlir::ArrayAttr> deviceTypes) {
2710 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2711 [&](
const auto &pair) {
2712 p << std::get<0>(pair);
2718 mlir::ArrayAttr &gang,
2719 mlir::ArrayAttr &gangDim,
2720 mlir::ArrayAttr &gangDimDeviceTypes) {
2723 gangDimDeviceTypeAttrs;
2724 bool needCommaBeforeOperands =
false;
2737 if (parser.parseAttribute(gangAttrs.emplace_back()))
2744 needCommaBeforeOperands =
true;
2747 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2751 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2752 parser.parseColon() ||
2753 parser.parseAttribute(gangDimAttrs.emplace_back()))
2755 if (succeeded(parser.parseOptionalLSquare())) {
2756 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2757 parser.parseRSquare())
2760 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2761 parser.getContext(), mlir::acc::DeviceType::None));
2767 if (failed(parser.parseRParen()))
2772 gangDimDeviceTypes =
2779 std::optional<mlir::ArrayAttr> gang,
2780 std::optional<mlir::ArrayAttr> gangDim,
2781 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2784 gang->size() == 1) {
2785 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2798 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2799 [&](
const auto &pair) {
2800 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
2801 p << std::get<0>(pair);
2809 mlir::ArrayAttr &deviceTypes) {
2822 if (parser.parseAttribute(attributes.emplace_back()))
2836 std::optional<mlir::ArrayAttr> deviceTypes) {
2839 auto deviceTypeAttr =
2840 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2850 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2858 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2864 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2870 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2874 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2878 std::optional<llvm::StringRef>
2879 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2881 return std::nullopt;
2882 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
2883 auto attr = (*getBindName())[*pos];
2884 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2885 return stringAttr.getValue();
2887 return std::nullopt;
2892 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2896 std::optional<int64_t> RoutineOp::getGangDimValue() {
2900 std::optional<int64_t>
2901 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2903 return std::nullopt;
2904 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
2905 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2906 return intAttr.getInt();
2908 return std::nullopt;
2919 return emitOpError(
"cannot be nested in a compute operation");
2931 return emitOpError(
"cannot be nested in a compute operation");
2943 return emitOpError(
"cannot be nested in a compute operation");
2944 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2945 return emitOpError(
"at least one default_async, device_num, or device_type "
2946 "operand must appear");
2956 if (getDataClauseOperands().empty())
2957 return emitError(
"at least one value must be present in dataOperands");
2960 getAsyncOperandsDeviceTypeAttr(),
2965 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2966 getWaitOperandsDeviceTypeAttr(),
"wait")))
2969 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
2972 for (
mlir::Value operand : getDataClauseOperands())
2973 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
2974 operand.getDefiningOp()))
2975 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2981 unsigned UpdateOp::getNumDataOperands() {
2982 return getDataClauseOperands().size();
2985 Value UpdateOp::getDataOperand(
unsigned i) {
2987 numOptional += getIfCond() ? 1 : 0;
2988 return getOperand(getWaitOperands().size() + numOptional + i);
2993 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
2996 bool UpdateOp::hasAsyncOnly() {
3000 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3008 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3018 bool UpdateOp::hasWaitOnly() {
3022 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3031 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3033 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3034 getHasWaitDevnum(), deviceType);
3041 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3043 getWaitOperandsSegments(), getHasWaitDevnum(),
3054 if (getAsyncOperand() && getAsync())
3055 return emitError(
"async attribute cannot appear with asyncOperand");
3057 if (getWaitDevnum() && getWaitOperands().empty())
3058 return emitError(
"wait_devnum cannot appear without waitOperands");
3063 #define GET_OP_CLASSES
3064 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3066 #define GET_ATTRDEF_CLASSES
3067 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3069 #define GET_TYPEDEF_CLASSES
3070 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3081 .Case<ACC_DATA_ENTRY_OPS>(
3082 [&](
auto entry) {
return entry.getVarPtr(); })
3083 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3084 [&](
auto exit) {
return exit.getVarPtr(); })
3102 [&](
auto entry) {
return entry.getVarType(); })
3103 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3104 [&](
auto exit) {
return exit.getVarType(); })
3114 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3115 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
3125 [&](
auto dataClause) {
return dataClause.getAccVar(); })
3134 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
3144 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3146 dataClause.getBounds().begin(), dataClause.getBounds().end());
3158 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3160 dataClause.getAsyncOperands().begin(),
3161 dataClause.getAsyncOperands().end());
3172 return dataClause.getAsyncOperandsDeviceTypeAttr();
3180 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
3187 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
3194 std::optional<mlir::acc::DataClause>
3199 .Case<ACC_DATA_ENTRY_OPS>(
3200 [&](
auto entry) {
return entry.getDataClause(); })
3208 [&](
auto entry) {
return entry.getImplicit(); })
3217 [&](
auto entry) {
return entry.getDataClauseOperands(); })
3219 return dataOperands;
3227 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
3229 return dataOperands;
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.