21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
28 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
36 static bool isScalarLikeType(
Type type) {
40 struct MemRefPointerLikeModel
41 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
44 return cast<MemRefType>(pointer).getElementType();
46 mlir::acc::VariableTypeCategory
49 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
50 return mappableTy.getTypeCategory(varPtr);
52 auto memrefTy = cast<MemRefType>(pointer);
53 if (!memrefTy.hasRank()) {
56 return mlir::acc::VariableTypeCategory::uncategorized;
59 if (memrefTy.getRank() == 0) {
60 if (isScalarLikeType(memrefTy.getElementType())) {
61 return mlir::acc::VariableTypeCategory::scalar;
65 return mlir::acc::VariableTypeCategory::uncategorized;
69 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
70 return mlir::acc::VariableTypeCategory::array;
74 struct LLVMPointerPointerLikeModel
75 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
76 LLVM::LLVMPointerType> {
85 void OpenACCDialect::initialize() {
88 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
91 #define GET_ATTRDEF_LIST
92 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
95 #define GET_TYPEDEF_LIST
96 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
102 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
103 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
112 if (arrayAttr && *arrayAttr && arrayAttr->size() > 0)
118 mlir::acc::DeviceType deviceType) {
122 for (
auto attr : *arrayAttr) {
123 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
124 if (deviceTypeAttr.getValue() == deviceType)
132 std::optional<mlir::ArrayAttr> deviceTypes) {
137 llvm::interleaveComma(*deviceTypes, p,
143 mlir::acc::DeviceType deviceType) {
144 unsigned segmentIdx = 0;
145 for (
auto attr : segments) {
146 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
147 if (deviceTypeAttr.getValue() == deviceType)
148 return std::make_optional(segmentIdx);
158 mlir::acc::DeviceType deviceType) {
160 return range.take_front(0);
161 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
162 int32_t nbOperandsBefore = 0;
163 for (
unsigned i = 0; i < *pos; ++i)
164 nbOperandsBefore += (*segments)[i];
165 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
167 return range.take_front(0);
174 std::optional<mlir::ArrayAttr> hasWaitDevnum,
175 mlir::acc::DeviceType deviceType) {
178 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
179 if (hasWaitDevnum->getValue()[*pos])
190 std::optional<mlir::ArrayAttr> hasWaitDevnum,
191 mlir::acc::DeviceType deviceType) {
196 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
197 if (hasWaitDevnum && *hasWaitDevnum) {
198 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
199 if (boolAttr.getValue())
200 return range.drop_front(1);
206 template <
typename Op>
208 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
210 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
215 op.hasAsyncOnly(dtype))
216 return op.
emitError(
"async attribute cannot appear with asyncOperand");
221 op.hasWaitOnly(dtype))
222 return op.
emitError(
"wait attribute cannot appear with waitOperands");
227 template <
typename Op>
230 return op.
emitError(
"must have var operand");
232 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
233 mlir::isa<mlir::acc::MappableType>(op.getVar().getType())) {
238 return op.
emitError(
"var must be mappable or pointer-like (not both)");
241 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
242 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
243 return op.
emitError(
"var must be mappable or pointer-like");
245 if (mlir::isa<mlir::acc::MappableType>(op.getVar().getType()) &&
246 op.getVarType() != op.getVar().getType())
247 return op.
emitError(
"varType must match when var is mappable");
252 template <
typename Op>
254 if (op.getVar().getType() != op.getAccVar().getType())
255 return op.
emitError(
"input and output types must match");
277 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
298 if (failed(parser.
parseType(accVarType)))
308 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
320 mlir::TypeAttr &varTypeAttr) {
321 if (failed(parser.
parseType(varPtrType)))
337 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
339 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
348 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
356 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
357 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
359 if (typeToCheckAgainst != varType) {
370 auto extent = getExtent();
371 auto upperbound = getUpperbound();
372 if (!extent && !upperbound)
373 return emitError(
"expected extent or upperbound.");
383 "data clause associated with private operation must match its intent");
394 return emitError(
"data clause associated with firstprivate operation must "
406 return emitError(
"data clause associated with reduction operation must "
418 return emitError(
"data clause associated with deviceptr operation must "
433 "data clause associated with present operation must match its intent");
446 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
451 "data clause associated with copyin operation must match its intent"
452 " or specify original clause this operation was decomposed from");
460 bool acc::CopyinOp::isCopyinReadonly() {
461 return getDataClause() == acc::DataClause::acc_copyin_readonly;
474 "data clause associated with create operation must match its intent"
475 " or specify original clause this operation was decomposed from");
483 bool acc::CreateOp::isCreateZero() {
485 return getDataClause() == acc::DataClause::acc_create_zero ||
494 return emitError(
"data clause associated with no_create operation must "
509 "data clause associated with attach operation must match its intent");
522 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
523 return emitError(
"data clause associated with device_resident operation "
524 "must match its intent");
539 "data clause associated with link operation must match its intent");
557 "data clause associated with copyout operation must match its intent"
558 " or specify original clause this operation was decomposed from");
560 return emitError(
"must have both host and device pointers");
568 bool acc::CopyoutOp::isCopyoutZero() {
584 getDataClause() != acc::DataClause::acc_declare_device_resident &&
587 "data clause associated with delete operation must match its intent"
588 " or specify original clause this operation was decomposed from");
590 return emitError(
"must have device pointer");
602 "data clause associated with detach operation must match its intent"
603 " or specify original clause this operation was decomposed from");
605 return emitError(
"must have device pointer");
617 "data clause associated with host operation must match its intent"
618 " or specify original clause this operation was decomposed from");
620 return emitError(
"must have both host and device pointers");
635 "data clause associated with device operation must match its intent"
636 " or specify original clause this operation was decomposed from");
651 "data clause associated with use_device operation must match its intent"
652 " or specify original clause this operation was decomposed from");
668 "data clause associated with cache operation must match its intent"
669 " or specify original clause this operation was decomposed from");
677 template <
typename StructureOp>
679 unsigned nRegions = 1) {
682 for (
unsigned i = 0; i < nRegions; ++i)
683 regions.push_back(state.addRegion());
685 for (
Region *region : regions)
693 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
700 template <
typename OpTy>
704 LogicalResult matchAndRewrite(OpTy op,
707 Value ifCond = op.getIfCond();
711 IntegerAttr constAttr;
714 if (constAttr.getInt())
715 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
727 assert(llvm::hasSingleElement(region) &&
"expected single-region block");
739 template <
typename OpTy>
740 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
743 LogicalResult matchAndRewrite(OpTy op,
746 Value ifCond = op.getIfCond();
750 IntegerAttr constAttr;
753 if (constAttr.getInt())
754 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
769 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
771 if (optional && region.
empty())
775 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
779 return op->
emitOpError() <<
"expects " << regionName
782 << regionType <<
" type";
785 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
786 if (yieldOp.getOperands().size() != 1 ||
787 yieldOp.getOperands().getTypes()[0] != type)
788 return op->
emitOpError() <<
"expects " << regionName
790 "yield a value of the "
791 << regionType <<
" type";
797 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
799 "privatization",
"init",
getType(),
803 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
813 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
815 "privatization",
"init",
getType(),
819 if (getCopyRegion().empty())
820 return emitOpError() <<
"expects non-empty copy region";
825 return emitOpError() <<
"expects copy region with two arguments of the "
826 "privatization type";
828 if (getDestroyRegion().empty())
832 "privatization",
"destroy",
843 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
849 if (getCombinerRegion().empty())
850 return emitOpError() <<
"expects non-empty combiner region";
852 Block &reductionBlock = getCombinerRegion().
front();
856 return emitOpError() <<
"expects combiner region with the first two "
857 <<
"arguments of the reduction type";
859 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
860 if (yieldOp.getOperands().size() != 1 ||
861 yieldOp.getOperands().getTypes()[0] !=
getType())
862 return emitOpError() <<
"expects combiner region to yield a value "
863 "of the reduction type";
879 if (parser.parseAttribute(attributes.emplace_back()) ||
880 parser.parseArrow() ||
881 parser.parseOperand(operands.emplace_back()) ||
882 parser.parseColonType(types.emplace_back()))
896 std::optional<mlir::ArrayAttr> attributes) {
897 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
898 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
899 << std::get<1>(it).getType();
908 template <
typename Op>
912 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
913 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
914 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
915 operand.getDefiningOp()))
917 "expect data entry/exit operation or acc.getdeviceptr "
922 template <
typename Op>
926 llvm::StringRef symbolName,
bool checkOperandType =
true) {
927 if (!operands.empty()) {
928 if (!attributes || attributes->size() != operands.size())
930 <<
"expected as many " << symbolName <<
" symbol reference as "
931 << operandName <<
" operands";
935 <<
"unexpected " << symbolName <<
" symbol reference";
940 for (
auto args : llvm::zip(operands, *attributes)) {
943 if (!set.insert(operand).second)
945 << operandName <<
" operand appears more than once";
948 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
949 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
952 <<
"expected symbol reference " << symbolRef <<
" to point to a "
953 << operandName <<
" declaration";
955 if (checkOperandType && decl.getType() && decl.getType() != varType)
956 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
957 <<
") to be the same type as " << operandName
958 <<
" declaration (" << decl.getType() <<
")";
964 unsigned ParallelOp::getNumDataOperands() {
965 return getReductionOperands().size() + getPrivateOperands().size() +
966 getFirstprivateOperands().size() + getDataClauseOperands().size();
969 Value ParallelOp::getDataOperand(
unsigned i) {
971 numOptional += getNumGangs().size();
972 numOptional += getNumWorkers().size();
973 numOptional += getVectorLength().size();
974 numOptional += getIfCond() ? 1 : 0;
975 numOptional += getSelfCond() ? 1 : 0;
976 return getOperand(getWaitOperands().size() + numOptional + i);
979 template <
typename Op>
981 ArrayAttr deviceTypes,
982 llvm::StringRef keyword) {
983 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
984 return op.
emitOpError() << keyword <<
" operands count must match "
985 << keyword <<
" device_type count";
989 template <
typename Op>
992 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
993 std::size_t numOperandsInSegments = 0;
994 std::size_t nbOfSegments = 0;
998 if (maxInSegment != 0 && segCount > maxInSegment)
999 return op.
emitOpError() << keyword <<
" expects a maximum of "
1000 << maxInSegment <<
" values per segment";
1001 numOperandsInSegments += segCount;
1006 if ((numOperandsInSegments != operands.size()) ||
1007 (!deviceTypes && !operands.empty()))
1009 << keyword <<
" operand count does not match count in segments";
1010 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1012 << keyword <<
" segment count does not match device_type count";
1017 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1018 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1019 "privatizations",
false)))
1021 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1022 *
this, getFirstprivatizations(), getFirstprivateOperands(),
1023 "firstprivate",
"firstprivatizations",
false)))
1025 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1026 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1027 "reductions",
false)))
1031 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1032 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1036 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1037 getWaitOperandsDeviceTypeAttr(),
"wait")))
1041 getNumWorkersDeviceTypeAttr(),
1046 getVectorLengthDeviceTypeAttr(),
1051 getAsyncOperandsDeviceTypeAttr(),
1055 if (failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1058 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1064 mlir::acc::DeviceType deviceType) {
1067 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1072 bool acc::ParallelOp::hasAsyncOnly() {
1076 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1084 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1089 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1094 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1099 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1104 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1106 getVectorLength(), deviceType);
1114 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1116 getNumGangsSegments(), deviceType);
1119 bool acc::ParallelOp::hasWaitOnly() {
1123 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1132 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1134 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1135 getHasWaitDevnum(), deviceType);
1142 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1144 getWaitOperandsSegments(), getHasWaitDevnum(),
1160 odsBuilder, odsState, asyncOperands,
nullptr,
1161 nullptr, waitOperands,
nullptr,
1163 nullptr, numGangs,
nullptr,
1164 nullptr, numWorkers,
1165 nullptr, vectorLength,
1166 nullptr, ifCond, selfCond,
1167 nullptr, reductionOperands,
nullptr,
1168 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1169 nullptr, dataClauseOperands,
1185 int32_t crtOperandsSize = operands.size();
1188 if (parser.parseOperand(operands.emplace_back()) ||
1189 parser.parseColonType(types.emplace_back()))
1194 seg.push_back(operands.size() - crtOperandsSize);
1218 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1220 p <<
" [" << attr <<
"]";
1225 std::optional<mlir::ArrayAttr> deviceTypes,
1226 std::optional<mlir::DenseI32ArrayAttr> segments) {
1228 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1230 llvm::interleaveComma(
1231 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1232 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1252 int32_t crtOperandsSize = operands.size();
1256 if (parser.parseOperand(operands.emplace_back()) ||
1257 parser.parseColonType(types.emplace_back()))
1263 seg.push_back(operands.size() - crtOperandsSize);
1289 std::optional<mlir::DenseI32ArrayAttr> segments) {
1291 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1293 llvm::interleaveComma(
1294 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1295 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1308 mlir::ArrayAttr &keywordOnly) {
1312 bool needCommaBeforeOperands =
false;
1325 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1332 needCommaBeforeOperands =
true;
1335 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1342 int32_t crtOperandsSize = operands.size();
1354 if (parser.parseOperand(operands.emplace_back()) ||
1355 parser.parseColonType(types.emplace_back()))
1361 seg.push_back(operands.size() - crtOperandsSize);
1390 if (attrs->size() != 1)
1392 if (
auto deviceTypeAttr =
1393 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1400 std::optional<mlir::ArrayAttr> deviceTypes,
1401 std::optional<mlir::DenseI32ArrayAttr> segments,
1402 std::optional<mlir::ArrayAttr> hasDevNum,
1403 std::optional<mlir::ArrayAttr> keywordOnly) {
1415 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1417 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1418 if (boolAttr && boolAttr.getValue())
1420 llvm::interleaveComma(
1421 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1422 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1438 if (parser.parseOperand(operands.emplace_back()) ||
1439 parser.parseColonType(types.emplace_back()))
1441 if (succeeded(parser.parseOptionalLSquare())) {
1442 if (parser.parseAttribute(attributes.emplace_back()) ||
1443 parser.parseRSquare())
1446 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1447 parser.getContext(), mlir::acc::DeviceType::None));
1461 std::optional<mlir::ArrayAttr> deviceTypes) {
1464 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1465 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1474 mlir::ArrayAttr &keywordOnlyDeviceType) {
1477 bool needCommaBeforeOperands =
false;
1483 keywordOnlyDeviceType =
1492 if (parser.parseAttribute(
1493 keywordOnlyDeviceTypeAttributes.emplace_back()))
1500 needCommaBeforeOperands =
true;
1503 if (needCommaBeforeOperands && failed(parser.
parseComma()))
1508 if (parser.parseOperand(operands.emplace_back()) ||
1509 parser.parseColonType(types.emplace_back()))
1511 if (succeeded(parser.parseOptionalLSquare())) {
1512 if (parser.parseAttribute(attributes.emplace_back()) ||
1513 parser.parseRSquare())
1516 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1517 parser.getContext(), mlir::acc::DeviceType::None));
1523 if (failed(parser.parseRParen()))
1535 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1537 if (operands.begin() == operands.end() &&
1553 mlir::acc::CombinedConstructsTypeAttr &attr) {
1559 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1562 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1565 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1568 "expected compute construct name");
1579 mlir::acc::CombinedConstructsTypeAttr attr) {
1581 switch (attr.getValue()) {
1582 case mlir::acc::CombinedConstructsType::KernelsLoop:
1583 p <<
"combined(kernels)";
1585 case mlir::acc::CombinedConstructsType::ParallelLoop:
1586 p <<
"combined(parallel)";
1588 case mlir::acc::CombinedConstructsType::SerialLoop:
1589 p <<
"combined(serial)";
1599 unsigned SerialOp::getNumDataOperands() {
1600 return getReductionOperands().size() + getPrivateOperands().size() +
1601 getFirstprivateOperands().size() + getDataClauseOperands().size();
1604 Value SerialOp::getDataOperand(
unsigned i) {
1606 numOptional += getIfCond() ? 1 : 0;
1607 numOptional += getSelfCond() ? 1 : 0;
1608 return getOperand(getWaitOperands().size() + numOptional + i);
1611 bool acc::SerialOp::hasAsyncOnly() {
1615 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1623 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1628 bool acc::SerialOp::hasWaitOnly() {
1632 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1641 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1643 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1644 getHasWaitDevnum(), deviceType);
1651 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1653 getWaitOperandsSegments(), getHasWaitDevnum(),
1658 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1659 *
this, getPrivatizations(), getPrivateOperands(),
"private",
1660 "privatizations",
false)))
1662 if (failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1663 *
this, getFirstprivatizations(), getFirstprivateOperands(),
1664 "firstprivate",
"firstprivatizations",
false)))
1666 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1667 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1668 "reductions",
false)))
1672 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1673 getWaitOperandsDeviceTypeAttr(),
"wait")))
1677 getAsyncOperandsDeviceTypeAttr(),
1681 if (failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
1684 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
1691 unsigned KernelsOp::getNumDataOperands() {
1692 return getDataClauseOperands().size();
1695 Value KernelsOp::getDataOperand(
unsigned i) {
1697 numOptional += getWaitOperands().size();
1698 numOptional += getNumGangs().size();
1699 numOptional += getNumWorkers().size();
1700 numOptional += getVectorLength().size();
1701 numOptional += getIfCond() ? 1 : 0;
1702 numOptional += getSelfCond() ? 1 : 0;
1703 return getOperand(numOptional + i);
1706 bool acc::KernelsOp::hasAsyncOnly() {
1710 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1718 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1723 mlir::Value acc::KernelsOp::getNumWorkersValue() {
1728 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1733 mlir::Value acc::KernelsOp::getVectorLengthValue() {
1738 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1740 getVectorLength(), deviceType);
1748 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1750 getNumGangsSegments(), deviceType);
1753 bool acc::KernelsOp::hasWaitOnly() {
1757 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1766 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1768 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1769 getHasWaitDevnum(), deviceType);
1776 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1778 getWaitOperandsSegments(), getHasWaitDevnum(),
1784 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1785 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1789 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1790 getWaitOperandsDeviceTypeAttr(),
"wait")))
1794 getNumWorkersDeviceTypeAttr(),
1799 getVectorLengthDeviceTypeAttr(),
1804 getAsyncOperandsDeviceTypeAttr(),
1808 if (failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
1811 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
1819 if (getDataClauseOperands().empty())
1820 return emitError(
"at least one operand must appear on the host_data "
1823 for (
mlir::Value operand : getDataClauseOperands())
1824 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
1825 return emitError(
"expect data entry operation as defining op");
1831 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
1843 bool &needCommaBetweenValues,
bool &newValue) {
1850 attributes.push_back(gangArgType);
1851 needCommaBetweenValues =
true;
1862 mlir::ArrayAttr &gangOnlyDeviceType) {
1867 bool needCommaBetweenValues =
false;
1868 bool needCommaBeforeOperands =
false;
1874 gangOnlyDeviceType =
1883 if (parser.parseAttribute(
1884 gangOnlyDeviceTypeAttributes.emplace_back()))
1891 needCommaBeforeOperands =
true;
1895 mlir::acc::GangArgType::Num);
1897 mlir::acc::GangArgType::Dim);
1899 parser.
getContext(), mlir::acc::GangArgType::Static);
1902 if (needCommaBeforeOperands) {
1903 needCommaBeforeOperands =
false;
1910 int32_t crtOperandsSize = gangOperands.size();
1912 bool newValue =
false;
1913 bool needValue =
false;
1914 if (needCommaBetweenValues) {
1922 gangOperands, gangOperandsType,
1923 gangArgTypeAttributes, argNum,
1924 needCommaBetweenValues, newValue)))
1927 gangOperands, gangOperandsType,
1928 gangArgTypeAttributes, argDim,
1929 needCommaBetweenValues, newValue)))
1931 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
1932 gangOperands, gangOperandsType,
1933 gangArgTypeAttributes, argStatic,
1934 needCommaBetweenValues, newValue)))
1937 if (!newValue && needValue) {
1939 "new value expected after comma");
1947 if (gangOperands.empty())
1950 "expect at least one of num, dim or static values");
1956 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
1964 seg.push_back(gangOperands.size() - crtOperandsSize);
1972 gangArgTypeAttributes.end());
1977 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
1986 std::optional<mlir::ArrayAttr> gangArgTypes,
1987 std::optional<mlir::ArrayAttr> deviceTypes,
1988 std::optional<mlir::DenseI32ArrayAttr> segments,
1989 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
1991 if (operands.begin() == operands.end() &&
2006 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2008 llvm::interleaveComma(
2009 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2010 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2011 (*gangArgTypes)[opIdx]);
2012 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2013 p << LoopOp::getGangNumKeyword();
2014 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2015 p << LoopOp::getGangDimKeyword();
2016 else if (gangArgTypeAttr.getValue() ==
2017 mlir::acc::GangArgType::Static)
2018 p << LoopOp::getGangStaticKeyword();
2019 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2030 std::optional<mlir::ArrayAttr> segments,
2031 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2034 for (
auto attr : *segments) {
2035 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2036 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2044 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2047 for (
auto attr : deviceTypes) {
2048 auto deviceTypeAttr =
2049 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2050 if (!deviceTypeAttr)
2052 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2059 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2060 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2061 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2062 <<
" as upperbound size";
2065 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2066 return emitOpError() <<
"collapse device_type attr must be define when"
2067 <<
" collapse attr is present";
2069 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2070 getCollapseAttr().getValue().size() !=
2071 getCollapseDeviceTypeAttr().getValue().size())
2072 return emitOpError() <<
"collapse attribute count must match collapse"
2073 <<
" device_type count";
2075 return emitOpError()
2076 <<
"duplicate device_type found in collapseDeviceType attribute";
2079 if (!getGangOperands().empty()) {
2080 if (!getGangOperandsArgType())
2081 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2082 <<
" when gang operands are present";
2084 if (getGangOperands().size() !=
2085 getGangOperandsArgTypeAttr().getValue().size())
2086 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2087 <<
" gangOperands count";
2090 return emitOpError() <<
"duplicate device_type found in gang attribute";
2093 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2094 getGangOperandsDeviceTypeAttr(),
"gang")))
2099 return emitOpError() <<
"duplicate device_type found in worker attribute";
2101 return emitOpError() <<
"duplicate device_type found in "
2102 "workerNumOperandsDeviceType attribute";
2104 getWorkerNumOperandsDeviceTypeAttr(),
2110 return emitOpError() <<
"duplicate device_type found in vector attribute";
2112 return emitOpError() <<
"duplicate device_type found in "
2113 "vectorOperandsDeviceType attribute";
2115 getVectorOperandsDeviceTypeAttr(),
2120 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2121 getTileOperandsDeviceTypeAttr(),
"tile")))
2125 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2129 return emitError() <<
"only one of \"" << acc::LoopOp::getAutoAttrStrName()
2130 <<
"\", " << getIndependentAttrName() <<
", "
2132 <<
" can be present at the same time";
2137 for (
auto attr : getSeqAttr()) {
2138 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2139 if (hasVector(deviceTypeAttr.getValue()) ||
2140 getVectorValue(deviceTypeAttr.getValue()) ||
2141 hasWorker(deviceTypeAttr.getValue()) ||
2142 getWorkerValue(deviceTypeAttr.getValue()) ||
2143 hasGang(deviceTypeAttr.getValue()) ||
2144 getGangValue(mlir::acc::GangArgType::Num,
2145 deviceTypeAttr.getValue()) ||
2146 getGangValue(mlir::acc::GangArgType::Dim,
2147 deviceTypeAttr.getValue()) ||
2148 getGangValue(mlir::acc::GangArgType::Static,
2149 deviceTypeAttr.getValue()))
2151 <<
"gang, worker or vector cannot appear with the seq attr";
2155 if (failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2156 *
this, getPrivatizations(), getPrivateOperands(),
"private",
2157 "privatizations",
false)))
2160 if (failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2161 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2162 "reductions",
false)))
2165 if (getCombined().has_value() &&
2166 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2167 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2168 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2169 return emitError(
"unexpected combined constructs attribute");
2173 if (getRegion().empty())
2174 return emitError(
"expected non-empty body.");
2179 unsigned LoopOp::getNumDataOperands() {
2180 return getReductionOperands().size() + getPrivateOperands().size();
2183 Value LoopOp::getDataOperand(
unsigned i) {
2184 unsigned numOptional =
2185 getLowerbound().size() + getUpperbound().size() + getStep().size();
2186 numOptional += getGangOperands().size();
2187 numOptional += getVectorOperands().size();
2188 numOptional += getWorkerNumOperands().size();
2189 numOptional += getTileOperands().size();
2190 numOptional += getCacheOperands().size();
2191 return getOperand(numOptional + i);
2196 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2200 bool LoopOp::hasIndependent() {
2204 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2210 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2218 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2220 getVectorOperands(), deviceType);
2225 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2233 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2235 getWorkerNumOperands(), deviceType);
2240 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2249 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2251 getTileOperandsSegments(), deviceType);
2254 std::optional<int64_t> LoopOp::getCollapseValue() {
2258 std::optional<int64_t>
2259 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2260 if (!getCollapseAttr())
2261 return std::nullopt;
2262 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2264 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2265 return intAttr.getValue().getZExtValue();
2267 return std::nullopt;
2270 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2274 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2275 mlir::acc::DeviceType deviceType) {
2276 if (getGangOperands().empty())
2278 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2279 int32_t nbOperandsBefore = 0;
2280 for (
unsigned i = 0; i < *pos; ++i)
2281 nbOperandsBefore += (*getGangOperandsSegments())[i];
2284 .drop_front(nbOperandsBefore)
2285 .take_front((*getGangOperandsSegments())[*pos]);
2287 int32_t argTypeIdx = nbOperandsBefore;
2288 for (
auto value : values) {
2289 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2290 (*getGangOperandsArgType())[argTypeIdx]);
2291 if (gangArgTypeAttr.getValue() == gangArgType)
2301 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2306 return {&getRegion()};
2350 if (!regionArgs.empty()) {
2351 p << acc::LoopOp::getControlKeyword() <<
"(";
2352 llvm::interleaveComma(regionArgs, p,
2354 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2355 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2356 <<
" : " << stepType <<
") ";
2369 if (getOperands().empty() && !getDefaultAttr())
2370 return emitError(
"at least one operand or the default attribute "
2371 "must appear on the data operation");
2373 for (
mlir::Value operand : getDataClauseOperands())
2374 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2375 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2376 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2377 operand.getDefiningOp()))
2378 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
2381 if (failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
2387 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
2389 Value DataOp::getDataOperand(
unsigned i) {
2390 unsigned numOptional = getIfCond() ? 1 : 0;
2392 numOptional += getWaitOperands().size();
2393 return getOperand(numOptional + i);
2396 bool acc::DataOp::hasAsyncOnly() {
2400 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2408 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2415 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2424 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2426 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2427 getHasWaitDevnum(), deviceType);
2434 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2436 getWaitOperandsSegments(), getHasWaitDevnum(),
2448 if (getDataClauseOperands().empty())
2449 return emitError(
"at least one operand must be present in dataOperands on "
2450 "the exit data operation");
2454 if (getAsyncOperand() && getAsync())
2455 return emitError(
"async attribute cannot appear with asyncOperand");
2459 if (!getWaitOperands().empty() && getWait())
2460 return emitError(
"wait attribute cannot appear with waitOperands");
2462 if (getWaitDevnum() && getWaitOperands().empty())
2463 return emitError(
"wait_devnum cannot appear without waitOperands");
2468 unsigned ExitDataOp::getNumDataOperands() {
2469 return getDataClauseOperands().size();
2472 Value ExitDataOp::getDataOperand(
unsigned i) {
2473 unsigned numOptional = getIfCond() ? 1 : 0;
2474 numOptional += getAsyncOperand() ? 1 : 0;
2475 numOptional += getWaitDevnum() ? 1 : 0;
2476 return getOperand(getWaitOperands().size() + numOptional + i);
2481 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
2492 if (getDataClauseOperands().empty())
2493 return emitError(
"at least one operand must be present in dataOperands on "
2494 "the enter data operation");
2498 if (getAsyncOperand() && getAsync())
2499 return emitError(
"async attribute cannot appear with asyncOperand");
2503 if (!getWaitOperands().empty() && getWait())
2504 return emitError(
"wait attribute cannot appear with waitOperands");
2506 if (getWaitDevnum() && getWaitOperands().empty())
2507 return emitError(
"wait_devnum cannot appear without waitOperands");
2509 for (
mlir::Value operand : getDataClauseOperands())
2510 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
2511 operand.getDefiningOp()))
2512 return emitError(
"expect data entry operation as defining op");
2517 unsigned EnterDataOp::getNumDataOperands() {
2518 return getDataClauseOperands().size();
2521 Value EnterDataOp::getDataOperand(
unsigned i) {
2522 unsigned numOptional = getIfCond() ? 1 : 0;
2523 numOptional += getAsyncOperand() ? 1 : 0;
2524 numOptional += getWaitDevnum() ? 1 : 0;
2525 return getOperand(getWaitOperands().size() + numOptional + i);
2530 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
2549 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
2556 if (
Value writeVal = op.getWriteOpVal()) {
2566 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
2572 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
2573 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
2575 return dyn_cast<AtomicReadOp>(getSecondOp());
2578 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
2579 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
2581 return dyn_cast<AtomicWriteOp>(getSecondOp());
2584 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
2585 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
2587 return dyn_cast<AtomicUpdateOp>(getSecondOp());
2590 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
2596 template <
typename Op>
2597 static LogicalResult
2599 bool requireAtLeastOneOperand =
true) {
2600 if (operands.empty() && requireAtLeastOneOperand)
2603 "at least one operand must appear on the declare operation");
2606 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2607 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
2608 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
2609 operand.getDefiningOp()))
2611 "expect valid declare data entry operation or acc.getdeviceptr "
2615 assert(var &&
"declare operands can only be data entry operations which "
2617 std::optional<mlir::acc::DataClause> dataClauseOptional{
2619 assert(dataClauseOptional.has_value() &&
2620 "declare operands can only be data entry operations which must have "
2624 if (!var.getDefiningOp())
2628 auto declareAttribute{
2630 if (!declareAttribute)
2632 "expect declare attribute on variable in declare operation");
2634 auto declAttr = mlir::cast<mlir::acc::DeclareAttr>(declareAttribute);
2635 if (declAttr.getDataClause().getValue() != dataClauseOptional.value())
2637 "expect matching declare attribute on variable in declare operation");
2644 if (declAttr.getImplicit() &&
2647 "implicitness must match between declare op and flag on variable");
2681 acc::DeviceType dtype) {
2682 unsigned parallelism = 0;
2683 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
2684 parallelism += op.hasWorker(dtype) ? 1 : 0;
2685 parallelism += op.hasVector(dtype) ? 1 : 0;
2686 parallelism += op.hasSeq(dtype) ? 1 : 0;
2691 unsigned baseParallelism =
2694 if (baseParallelism > 1)
2695 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2696 "be present at the same time";
2698 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
2700 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
2705 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
2706 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
2707 "be present at the same time";
2714 mlir::ArrayAttr &deviceTypes) {
2719 if (parser.parseAttribute(bindNameAttrs.emplace_back()))
2721 if (failed(parser.parseOptionalLSquare())) {
2722 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2723 parser.getContext(), mlir::acc::DeviceType::None));
2725 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2726 parser.parseRSquare())
2734 deviceTypes =
ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2740 std::optional<mlir::ArrayAttr> bindName,
2741 std::optional<mlir::ArrayAttr> deviceTypes) {
2742 llvm::interleaveComma(llvm::zip(*bindName, *deviceTypes), p,
2743 [&](
const auto &pair) {
2744 p << std::get<0>(pair);
2750 mlir::ArrayAttr &gang,
2751 mlir::ArrayAttr &gangDim,
2752 mlir::ArrayAttr &gangDimDeviceTypes) {
2755 gangDimDeviceTypeAttrs;
2756 bool needCommaBeforeOperands =
false;
2769 if (parser.parseAttribute(gangAttrs.emplace_back()))
2776 needCommaBeforeOperands =
true;
2779 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2783 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
2784 parser.parseColon() ||
2785 parser.parseAttribute(gangDimAttrs.emplace_back()))
2787 if (succeeded(parser.parseOptionalLSquare())) {
2788 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
2789 parser.parseRSquare())
2792 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2793 parser.getContext(), mlir::acc::DeviceType::None));
2799 if (failed(parser.parseRParen()))
2804 gangDimDeviceTypes =
2811 std::optional<mlir::ArrayAttr> gang,
2812 std::optional<mlir::ArrayAttr> gangDim,
2813 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
2816 gang->size() == 1) {
2817 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
2830 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
2831 [&](
const auto &pair) {
2832 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
2833 p << std::get<0>(pair);
2841 mlir::ArrayAttr &deviceTypes) {
2854 if (parser.parseAttribute(attributes.emplace_back()))
2868 std::optional<mlir::ArrayAttr> deviceTypes) {
2871 auto deviceTypeAttr =
2872 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
2882 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2890 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
2896 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
2902 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
2906 std::optional<llvm::StringRef> RoutineOp::getBindNameValue() {
2910 std::optional<llvm::StringRef>
2911 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
2913 return std::nullopt;
2914 if (
auto pos =
findSegment(*getBindNameDeviceType(), deviceType)) {
2915 auto attr = (*getBindName())[*pos];
2916 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
2917 return stringAttr.getValue();
2919 return std::nullopt;
2924 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
2928 std::optional<int64_t> RoutineOp::getGangDimValue() {
2932 std::optional<int64_t>
2933 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
2935 return std::nullopt;
2936 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
2937 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
2938 return intAttr.getInt();
2940 return std::nullopt;
2951 return emitOpError(
"cannot be nested in a compute operation");
2963 return emitOpError(
"cannot be nested in a compute operation");
2975 return emitOpError(
"cannot be nested in a compute operation");
2976 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
2977 return emitOpError(
"at least one default_async, device_num, or device_type "
2978 "operand must appear");
2988 if (getDataClauseOperands().empty())
2989 return emitError(
"at least one value must be present in dataOperands");
2992 getAsyncOperandsDeviceTypeAttr(),
2997 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2998 getWaitOperandsDeviceTypeAttr(),
"wait")))
3001 if (failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
3004 for (
mlir::Value operand : getDataClauseOperands())
3005 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3006 operand.getDefiningOp()))
3007 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3013 unsigned UpdateOp::getNumDataOperands() {
3014 return getDataClauseOperands().size();
3017 Value UpdateOp::getDataOperand(
unsigned i) {
3019 numOptional += getIfCond() ? 1 : 0;
3020 return getOperand(getWaitOperands().size() + numOptional + i);
3025 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
3028 bool UpdateOp::hasAsyncOnly() {
3032 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3040 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3050 bool UpdateOp::hasWaitOnly() {
3054 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3063 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3065 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3066 getHasWaitDevnum(), deviceType);
3073 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3075 getWaitOperandsSegments(), getHasWaitDevnum(),
3086 if (getAsyncOperand() && getAsync())
3087 return emitError(
"async attribute cannot appear with asyncOperand");
3089 if (getWaitDevnum() && getWaitOperands().empty())
3090 return emitError(
"wait_devnum cannot appear without waitOperands");
3095 #define GET_OP_CLASSES
3096 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
3098 #define GET_ATTRDEF_CLASSES
3099 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
3101 #define GET_TYPEDEF_CLASSES
3102 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
3113 .Case<ACC_DATA_ENTRY_OPS>(
3114 [&](
auto entry) {
return entry.getVarPtr(); })
3115 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3116 [&](
auto exit) {
return exit.getVarPtr(); })
3134 [&](
auto entry) {
return entry.getVarType(); })
3135 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
3136 [&](
auto exit) {
return exit.getVarType(); })
3146 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
3147 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
3157 [&](
auto dataClause) {
return dataClause.getAccVar(); })
3166 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
3176 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3178 dataClause.getBounds().begin(), dataClause.getBounds().end());
3190 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
3192 dataClause.getAsyncOperands().begin(),
3193 dataClause.getAsyncOperands().end());
3204 return dataClause.getAsyncOperandsDeviceTypeAttr();
3212 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
3219 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
3226 std::optional<mlir::acc::DataClause>
3231 .Case<ACC_DATA_ENTRY_OPS>(
3232 [&](
auto entry) {
return entry.getDataClause(); })
3240 [&](
auto entry) {
return entry.getImplicit(); })
3249 [&](
auto entry) {
return entry.getDataClauseOperands(); })
3251 return dataOperands;
3259 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
3261 return dataOperands;
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindName, mlir::ArrayAttr &deviceTypes)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindName, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Type getElementType(Type type, ArrayRef< int32_t > indices, function_ref< InFlightDiagnostic(StringRef)> emitErrorFn)
Walks the given type hierarchy with the given indices, potentially down to component granularity,...
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues=std::nullopt)
Inline the operations of block 'source' into block 'dest' before the given position.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
static constexpr StringLiteral getDeclareAttrName()
Used to obtain the attribute name for declare.
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.