21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
33 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
37 static bool isScalarLikeType(
Type type) {
41 struct MemRefPointerLikeModel
42 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
45 return cast<MemRefType>(pointer).getElementType();
47 mlir::acc::VariableTypeCategory
50 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
51 return mappableTy.getTypeCategory(varPtr);
53 auto memrefTy = cast<MemRefType>(pointer);
54 if (!memrefTy.hasRank()) {
57 return mlir::acc::VariableTypeCategory::uncategorized;
60 if (memrefTy.getRank() == 0) {
61 if (isScalarLikeType(memrefTy.getElementType())) {
62 return mlir::acc::VariableTypeCategory::scalar;
66 return mlir::acc::VariableTypeCategory::uncategorized;
70 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
71 return mlir::acc::VariableTypeCategory::array;
75 struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
85 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
86 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
89 if (existingDeviceTypes)
90 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
92 if (newDeviceTypes.empty())
93 deviceTypes.push_back(
96 for (DeviceType DT : newDeviceTypes)
108 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
109 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
114 if (existingDeviceTypes)
115 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
117 if (newDeviceTypes.empty()) {
118 argCollection.
append(arguments);
119 segments.push_back(arguments.size());
120 deviceTypes.push_back(
124 for (DeviceType DT : newDeviceTypes) {
125 argCollection.
append(arguments);
126 segments.push_back(arguments.size());
134 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
135 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
139 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
140 newDeviceTypes, arguments,
141 argCollection, segments);
149 void OpenACCDialect::initialize() {
152 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
155 #define GET_ATTRDEF_LIST
156 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
159 #define GET_TYPEDEF_LIST
160 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
166 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
167 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
176 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
180 mlir::acc::DeviceType deviceType) {
184 for (
auto attr : *arrayAttr) {
185 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
186 if (deviceTypeAttr.getValue() == deviceType)
194 std::optional<mlir::ArrayAttr> deviceTypes) {
199 llvm::interleaveComma(*deviceTypes, p,
205 mlir::acc::DeviceType deviceType) {
206 unsigned segmentIdx = 0;
207 for (
auto attr : segments) {
208 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
209 if (deviceTypeAttr.getValue() == deviceType)
210 return std::make_optional(segmentIdx);
220 mlir::acc::DeviceType deviceType) {
222 return range.take_front(0);
223 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
224 int32_t nbOperandsBefore = 0;
225 for (
unsigned i = 0; i < *pos; ++i)
226 nbOperandsBefore += (*segments)[i];
227 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
229 return range.take_front(0);
236 std::optional<mlir::ArrayAttr> hasWaitDevnum,
237 mlir::acc::DeviceType deviceType) {
240 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
241 if (hasWaitDevnum->getValue()[*pos])
252 std::optional<mlir::ArrayAttr> hasWaitDevnum,
253 mlir::acc::DeviceType deviceType) {
258 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
259 if (hasWaitDevnum && *hasWaitDevnum) {
260 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
261 if (boolAttr.getValue())
262 return range.drop_front(1);
268 template <
typename Op>
270 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
277 op.hasAsyncOnly(dtype))
279 "asyncOnly attribute cannot appear with asyncOperand");
284 op.hasWaitOnly(dtype))
285 return op.
emitError(
"wait attribute cannot appear with waitOperands");
290 template <
typename Op>
293 return op.
emitError(
"must have var operand");
296 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
298 return op.
emitError(
"var must be mappable or pointer-like");
301 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
302 op.getVarType() == op.getVar().getType())
303 return op.
emitError(
"varType must capture the element type of var");
308 template <
typename Op>
310 if (op.getVar().getType() != op.getAccVar().getType())
311 return op.
emitError(
"input and output types must match");
316 template <
typename Op>
318 if (op.getModifiers() != acc::DataClauseModifier::none)
319 return op.
emitError(
"no data clause modifiers are allowed");
323 template <
typename Op>
326 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
328 "invalid data clause modifiers: " +
329 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
351 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
382 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
394 mlir::TypeAttr &varTypeAttr) {
411 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
413 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
422 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
430 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
431 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
433 if (typeToCheckAgainst != varType) {
444 auto extent = getExtent();
445 auto upperbound = getUpperbound();
446 if (!extent && !upperbound)
447 return emitError(
"expected extent or upperbound.");
457 "data clause associated with private operation must match its intent");
470 return emitError(
"data clause associated with firstprivate operation must "
484 return emitError(
"data clause associated with reduction operation must "
498 return emitError(
"data clause associated with deviceptr operation must "
515 "data clause associated with present operation must match its intent");
530 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
535 "data clause associated with copyin operation must match its intent"
536 " or specify original clause this operation was decomposed from");
542 acc::DataClauseModifier::always |
543 acc::DataClauseModifier::capture)))
548 bool acc::CopyinOp::isCopyinReadonly() {
549 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
550 acc::bitEnumContainsAny(getModifiers(),
551 acc::DataClauseModifier::readonly);
564 "data clause associated with create operation must match its intent"
565 " or specify original clause this operation was decomposed from");
573 acc::DataClauseModifier::always |
574 acc::DataClauseModifier::capture)))
579 bool acc::CreateOp::isCreateZero() {
581 return getDataClause() == acc::DataClause::acc_create_zero ||
583 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
591 return emitError(
"data clause associated with no_create operation must "
608 "data clause associated with attach operation must match its intent");
623 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
624 return emitError(
"data clause associated with device_resident operation "
625 "must match its intent");
642 "data clause associated with link operation must match its intent");
662 "data clause associated with copyout operation must match its intent"
663 " or specify original clause this operation was decomposed from");
665 return emitError(
"must have both host and device pointers");
671 acc::DataClauseModifier::always |
672 acc::DataClauseModifier::capture)))
677 bool acc::CopyoutOp::isCopyoutZero() {
678 return getDataClause() == acc::DataClause::acc_copyout_zero ||
679 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
694 getDataClause() != acc::DataClause::acc_declare_device_resident &&
697 "data clause associated with delete operation must match its intent"
698 " or specify original clause this operation was decomposed from");
700 return emitError(
"must have device pointer");
704 acc::DataClauseModifier::readonly |
705 acc::DataClauseModifier::always |
706 acc::DataClauseModifier::capture)))
719 "data clause associated with detach operation must match its intent"
720 " or specify original clause this operation was decomposed from");
722 return emitError(
"must have device pointer");
736 "data clause associated with host operation must match its intent"
737 " or specify original clause this operation was decomposed from");
739 return emitError(
"must have both host and device pointers");
756 "data clause associated with device operation must match its intent"
757 " or specify original clause this operation was decomposed from");
774 "data clause associated with use_device operation must match its intent"
775 " or specify original clause this operation was decomposed from");
793 "data clause associated with cache operation must match its intent"
794 " or specify original clause this operation was decomposed from");
804 bool acc::CacheOp::isCacheReadonly() {
805 return getDataClause() == acc::DataClause::acc_cache_readonly ||
806 acc::bitEnumContainsAny(getModifiers(),
807 acc::DataClauseModifier::readonly);
810 template <
typename StructureOp>
812 unsigned nRegions = 1) {
815 for (
unsigned i = 0; i < nRegions; ++i)
816 regions.push_back(state.addRegion());
818 for (
Region *region : regions)
826 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
833 template <
typename OpTy>
837 LogicalResult matchAndRewrite(OpTy op,
840 Value ifCond = op.getIfCond();
844 IntegerAttr constAttr;
847 if (constAttr.getInt())
848 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
860 assert(region.
hasOneBlock() &&
"expected single-block region");
872 template <
typename OpTy>
873 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
876 LogicalResult matchAndRewrite(OpTy op,
879 Value ifCond = op.getIfCond();
883 IntegerAttr constAttr;
886 if (constAttr.getInt())
887 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
902 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
904 if (optional && region.
empty())
908 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
912 return op->
emitOpError() <<
"expects " << regionName
915 << regionType <<
" type";
918 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
919 if (yieldOp.getOperands().size() != 1 ||
920 yieldOp.getOperands().getTypes()[0] != type)
921 return op->
emitOpError() <<
"expects " << regionName
923 "yield a value of the "
924 << regionType <<
" type";
930 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
932 "privatization",
"init",
getType(),
936 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
946 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
948 "privatization",
"init",
getType(),
952 if (getCopyRegion().empty())
953 return emitOpError() <<
"expects non-empty copy region";
958 return emitOpError() <<
"expects copy region with two arguments of the "
959 "privatization type";
961 if (getDestroyRegion().empty())
965 "privatization",
"destroy",
976 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
982 if (getCombinerRegion().empty())
983 return emitOpError() <<
"expects non-empty combiner region";
985 Block &reductionBlock = getCombinerRegion().
front();
989 return emitOpError() <<
"expects combiner region with the first two "
990 <<
"arguments of the reduction type";
992 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
993 if (yieldOp.getOperands().size() != 1 ||
994 yieldOp.getOperands().getTypes()[0] !=
getType())
995 return emitOpError() <<
"expects combiner region to yield a value "
996 "of the reduction type";
1012 if (parser.parseAttribute(attributes.emplace_back()) ||
1013 parser.parseArrow() ||
1014 parser.parseOperand(operands.emplace_back()) ||
1015 parser.parseColonType(types.emplace_back()))
1029 std::optional<mlir::ArrayAttr> attributes) {
1030 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
1031 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
1032 << std::get<1>(it).getType();
1041 template <
typename Op>
1045 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1046 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1047 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1048 operand.getDefiningOp()))
1050 "expect data entry/exit operation or acc.getdeviceptr "
1055 template <
typename Op>
1056 static LogicalResult
1059 llvm::StringRef symbolName,
bool checkOperandType =
true) {
1060 if (!operands.empty()) {
1061 if (!attributes || attributes->size() != operands.size())
1063 <<
"expected as many " << symbolName <<
" symbol reference as "
1064 << operandName <<
" operands";
1068 <<
"unexpected " << symbolName <<
" symbol reference";
1073 for (
auto args : llvm::zip(operands, *attributes)) {
1076 if (!set.insert(operand).second)
1078 << operandName <<
" operand appears more than once";
1081 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1082 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1085 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1086 << operandName <<
" declaration";
1088 if (checkOperandType && decl.getType() && decl.getType() != varType)
1089 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1090 <<
") to be the same type as " << operandName
1091 <<
" declaration (" << decl.getType() <<
")";
1097 unsigned ParallelOp::getNumDataOperands() {
1098 return getReductionOperands().size() + getPrivateOperands().size() +
1099 getFirstprivateOperands().size() + getDataClauseOperands().size();
1102 Value ParallelOp::getDataOperand(
unsigned i) {
1104 numOptional += getNumGangs().size();
1105 numOptional += getNumWorkers().size();
1106 numOptional += getVectorLength().size();
1107 numOptional += getIfCond() ? 1 : 0;
1108 numOptional += getSelfCond() ? 1 : 0;
1109 return getOperand(getWaitOperands().size() + numOptional + i);
1112 template <
typename Op>
1114 ArrayAttr deviceTypes,
1115 llvm::StringRef keyword) {
1116 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1117 return op.
emitOpError() << keyword <<
" operands count must match "
1118 << keyword <<
" device_type count";
1122 template <
typename Op>
1125 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1126 std::size_t numOperandsInSegments = 0;
1127 std::size_t nbOfSegments = 0;
1130 for (
auto segCount : segments.
asArrayRef()) {
1131 if (maxInSegment != 0 && segCount > maxInSegment)
1132 return op.
emitOpError() << keyword <<
" expects a maximum of "
1133 << maxInSegment <<
" values per segment";
1134 numOperandsInSegments += segCount;
1139 if ((numOperandsInSegments != operands.size()) ||
1140 (!deviceTypes && !operands.empty()))
1142 << keyword <<
" operand count does not match count in segments";
1143 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1145 << keyword <<
" segment count does not match device_type count";
1150 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1151 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1152 "privatizations",
false)))
1154 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1155 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1156 "firstprivate",
"firstprivatizations",
false)))
1158 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1159 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1160 "reductions",
false)))
1164 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1165 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1169 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1170 getWaitOperandsDeviceTypeAttr(),
"wait")))
1174 getNumWorkersDeviceTypeAttr(),
1179 getVectorLengthDeviceTypeAttr(),
1184 getAsyncOperandsDeviceTypeAttr(),
1188 if (
failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1191 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1197 mlir::acc::DeviceType deviceType) {
1200 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1205 bool acc::ParallelOp::hasAsyncOnly() {
1209 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1217 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1222 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1227 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1232 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1237 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1239 getVectorLength(), deviceType);
1247 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1249 getNumGangsSegments(), deviceType);
1252 bool acc::ParallelOp::hasWaitOnly() {
1256 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1265 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1267 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1268 getHasWaitDevnum(), deviceType);
1275 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1277 getWaitOperandsSegments(), getHasWaitDevnum(),
1293 odsBuilder, odsState, asyncOperands,
nullptr,
1294 nullptr, waitOperands,
nullptr,
1296 nullptr, numGangs,
nullptr,
1297 nullptr, numWorkers,
1298 nullptr, vectorLength,
1299 nullptr, ifCond, selfCond,
1300 nullptr, reductionOperands,
nullptr,
1301 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1302 nullptr, dataClauseOperands,
1306 void acc::ParallelOp::addNumWorkersOperand(
1309 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1310 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1311 getNumWorkersMutable()));
1313 void acc::ParallelOp::addVectorLengthOperand(
1316 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1317 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1318 getVectorLengthMutable()));
1321 void acc::ParallelOp::addAsyncOnly(
1323 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1324 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1327 void acc::ParallelOp::addAsyncOperand(
1330 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1331 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1332 getAsyncOperandsMutable()));
1335 void acc::ParallelOp::addNumGangsOperands(
1339 if (getNumGangsSegments())
1340 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1342 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1343 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1344 getNumGangsMutable(), segments));
1346 setNumGangsSegments(segments);
1348 void acc::ParallelOp::addWaitOnly(
1350 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1351 effectiveDeviceTypes));
1353 void acc::ParallelOp::addWaitOperands(
1358 if (getWaitOperandsSegments())
1359 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1361 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1362 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1363 getWaitOperandsMutable(), segments));
1364 setWaitOperandsSegments(segments);
1367 if (getHasWaitDevnumAttr())
1368 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1371 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1376 void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1377 mlir::acc::PrivateOp op,
1378 mlir::acc::PrivateRecipeOp recipe) {
1379 getPrivateOperandsMutable().append(op.getResult());
1383 if (getPrivatizationRecipesAttr())
1384 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1391 void acc::ParallelOp::addFirstPrivatization(
1392 MLIRContext *context, mlir::acc::FirstprivateOp op,
1393 mlir::acc::FirstprivateRecipeOp recipe) {
1394 getFirstprivateOperandsMutable().append(op.getResult());
1398 if (getFirstprivatizationRecipesAttr())
1399 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1406 void acc::ParallelOp::addReduction(
MLIRContext *context,
1407 mlir::acc::ReductionOp op,
1408 mlir::acc::ReductionRecipeOp recipe) {
1409 getReductionOperandsMutable().append(op.getResult());
1413 if (getReductionRecipesAttr())
1414 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1433 int32_t crtOperandsSize = operands.size();
1436 if (parser.parseOperand(operands.emplace_back()) ||
1437 parser.parseColonType(types.emplace_back()))
1442 seg.push_back(operands.size() - crtOperandsSize);
1466 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1468 p <<
" [" << attr <<
"]";
1473 std::optional<mlir::ArrayAttr> deviceTypes,
1474 std::optional<mlir::DenseI32ArrayAttr> segments) {
1476 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1478 llvm::interleaveComma(
1479 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1480 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1500 int32_t crtOperandsSize = operands.size();
1504 if (parser.parseOperand(operands.emplace_back()) ||
1505 parser.parseColonType(types.emplace_back()))
1511 seg.push_back(operands.size() - crtOperandsSize);
1537 std::optional<mlir::DenseI32ArrayAttr> segments) {
1539 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1541 llvm::interleaveComma(
1542 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1543 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1556 mlir::ArrayAttr &keywordOnly) {
1560 bool needCommaBeforeOperands =
false;
1573 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1580 needCommaBeforeOperands =
true;
1590 int32_t crtOperandsSize = operands.size();
1602 if (parser.parseOperand(operands.emplace_back()) ||
1603 parser.parseColonType(types.emplace_back()))
1609 seg.push_back(operands.size() - crtOperandsSize);
1638 if (attrs->size() != 1)
1640 if (
auto deviceTypeAttr =
1641 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1648 std::optional<mlir::ArrayAttr> deviceTypes,
1649 std::optional<mlir::DenseI32ArrayAttr> segments,
1650 std::optional<mlir::ArrayAttr> hasDevNum,
1651 std::optional<mlir::ArrayAttr> keywordOnly) {
1664 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1666 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1667 if (boolAttr && boolAttr.getValue())
1669 llvm::interleaveComma(
1670 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1671 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1688 if (parser.parseOperand(operands.emplace_back()) ||
1689 parser.parseColonType(types.emplace_back()))
1691 if (succeeded(parser.parseOptionalLSquare())) {
1692 if (parser.parseAttribute(attributes.emplace_back()) ||
1693 parser.parseRSquare())
1696 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1697 parser.getContext(), mlir::acc::DeviceType::None));
1711 std::optional<mlir::ArrayAttr> deviceTypes) {
1714 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1715 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1724 mlir::ArrayAttr &keywordOnlyDeviceType) {
1727 bool needCommaBeforeOperands =
false;
1733 keywordOnlyDeviceType =
1742 if (parser.parseAttribute(
1743 keywordOnlyDeviceTypeAttributes.emplace_back()))
1750 needCommaBeforeOperands =
true;
1758 if (parser.parseOperand(operands.emplace_back()) ||
1759 parser.parseColonType(types.emplace_back()))
1761 if (succeeded(parser.parseOptionalLSquare())) {
1762 if (parser.parseAttribute(attributes.emplace_back()) ||
1763 parser.parseRSquare())
1766 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1767 parser.getContext(), mlir::acc::DeviceType::None));
1773 if (
failed(parser.parseRParen()))
1785 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1787 if (operands.begin() == operands.end() &&
1803 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1804 mlir::Type &operandType, mlir::UnitAttr &attr) {
1827 std::optional<mlir::Value> operand,
1829 mlir::UnitAttr attr) {
1851 if (parser.parseOperand(operands.emplace_back()))
1859 if (parser.parseType(types.emplace_back()))
1874 mlir::UnitAttr attr) {
1879 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
1881 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
1887 mlir::acc::CombinedConstructsTypeAttr &attr) {
1890 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1893 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1896 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1899 "expected compute construct name");
1907 mlir::acc::CombinedConstructsTypeAttr attr) {
1909 switch (attr.getValue()) {
1910 case mlir::acc::CombinedConstructsType::KernelsLoop:
1913 case mlir::acc::CombinedConstructsType::ParallelLoop:
1916 case mlir::acc::CombinedConstructsType::SerialLoop:
1927 unsigned SerialOp::getNumDataOperands() {
1928 return getReductionOperands().size() + getPrivateOperands().size() +
1929 getFirstprivateOperands().size() + getDataClauseOperands().size();
1932 Value SerialOp::getDataOperand(
unsigned i) {
1934 numOptional += getIfCond() ? 1 : 0;
1935 numOptional += getSelfCond() ? 1 : 0;
1936 return getOperand(getWaitOperands().size() + numOptional + i);
1939 bool acc::SerialOp::hasAsyncOnly() {
1943 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1951 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1956 bool acc::SerialOp::hasWaitOnly() {
1960 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1969 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1971 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1972 getHasWaitDevnum(), deviceType);
1979 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1981 getWaitOperandsSegments(), getHasWaitDevnum(),
1986 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1987 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1988 "privatizations",
false)))
1990 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1991 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1992 "firstprivate",
"firstprivatizations",
false)))
1994 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1995 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1996 "reductions",
false)))
2000 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2001 getWaitOperandsDeviceTypeAttr(),
"wait")))
2005 getAsyncOperandsDeviceTypeAttr(),
2009 if (
failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
2012 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
2015 void acc::SerialOp::addAsyncOnly(
2017 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2018 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2021 void acc::SerialOp::addAsyncOperand(
2024 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2025 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2026 getAsyncOperandsMutable()));
2029 void acc::SerialOp::addWaitOnly(
2031 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2032 effectiveDeviceTypes));
2034 void acc::SerialOp::addWaitOperands(
2039 if (getWaitOperandsSegments())
2040 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2042 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2043 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2044 getWaitOperandsMutable(), segments));
2045 setWaitOperandsSegments(segments);
2048 if (getHasWaitDevnumAttr())
2049 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2052 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2057 void acc::SerialOp::addPrivatization(
MLIRContext *context,
2058 mlir::acc::PrivateOp op,
2059 mlir::acc::PrivateRecipeOp recipe) {
2060 getPrivateOperandsMutable().append(op.getResult());
2064 if (getPrivatizationRecipesAttr())
2065 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2072 void acc::SerialOp::addFirstPrivatization(
2073 MLIRContext *context, mlir::acc::FirstprivateOp op,
2074 mlir::acc::FirstprivateRecipeOp recipe) {
2075 getFirstprivateOperandsMutable().append(op.getResult());
2079 if (getFirstprivatizationRecipesAttr())
2080 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2087 void acc::SerialOp::addReduction(
MLIRContext *context,
2088 mlir::acc::ReductionOp op,
2089 mlir::acc::ReductionRecipeOp recipe) {
2090 getReductionOperandsMutable().append(op.getResult());
2094 if (getReductionRecipesAttr())
2095 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2106 unsigned KernelsOp::getNumDataOperands() {
2107 return getDataClauseOperands().size();
2110 Value KernelsOp::getDataOperand(
unsigned i) {
2112 numOptional += getWaitOperands().size();
2113 numOptional += getNumGangs().size();
2114 numOptional += getNumWorkers().size();
2115 numOptional += getVectorLength().size();
2116 numOptional += getIfCond() ? 1 : 0;
2117 numOptional += getSelfCond() ? 1 : 0;
2118 return getOperand(numOptional + i);
2121 bool acc::KernelsOp::hasAsyncOnly() {
2125 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2133 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2138 mlir::Value acc::KernelsOp::getNumWorkersValue() {
2143 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2148 mlir::Value acc::KernelsOp::getVectorLengthValue() {
2153 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2155 getVectorLength(), deviceType);
2163 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2165 getNumGangsSegments(), deviceType);
2168 bool acc::KernelsOp::hasWaitOnly() {
2172 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2181 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2183 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2184 getHasWaitDevnum(), deviceType);
2191 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2193 getWaitOperandsSegments(), getHasWaitDevnum(),
2199 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2200 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2204 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2205 getWaitOperandsDeviceTypeAttr(),
"wait")))
2209 getNumWorkersDeviceTypeAttr(),
2214 getVectorLengthDeviceTypeAttr(),
2219 getAsyncOperandsDeviceTypeAttr(),
2223 if (
failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
2226 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
2229 void acc::KernelsOp::addNumWorkersOperand(
2232 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2233 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2234 getNumWorkersMutable()));
2237 void acc::KernelsOp::addVectorLengthOperand(
2240 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2241 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2242 getVectorLengthMutable()));
2244 void acc::KernelsOp::addAsyncOnly(
2246 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2247 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2250 void acc::KernelsOp::addAsyncOperand(
2253 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2254 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2255 getAsyncOperandsMutable()));
2258 void acc::KernelsOp::addNumGangsOperands(
2262 if (getNumGangsSegmentsAttr())
2263 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2265 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2266 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2267 getNumGangsMutable(), segments));
2269 setNumGangsSegments(segments);
2272 void acc::KernelsOp::addWaitOnly(
2274 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2275 effectiveDeviceTypes));
2277 void acc::KernelsOp::addWaitOperands(
2282 if (getWaitOperandsSegments())
2283 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2285 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2286 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2287 getWaitOperandsMutable(), segments));
2288 setWaitOperandsSegments(segments);
2291 if (getHasWaitDevnumAttr())
2292 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2295 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2305 if (getDataClauseOperands().empty())
2306 return emitError(
"at least one operand must appear on the host_data "
2309 for (
mlir::Value operand : getDataClauseOperands())
2310 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2311 return emitError(
"expect data entry operation as defining op");
2317 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2329 bool &needCommaBetweenValues,
bool &newValue) {
2336 attributes.push_back(gangArgType);
2337 needCommaBetweenValues =
true;
2348 mlir::ArrayAttr &gangOnlyDeviceType) {
2353 bool needCommaBetweenValues =
false;
2354 bool needCommaBeforeOperands =
false;
2360 gangOnlyDeviceType =
2369 if (parser.parseAttribute(
2370 gangOnlyDeviceTypeAttributes.emplace_back()))
2377 needCommaBeforeOperands =
true;
2381 mlir::acc::GangArgType::Num);
2383 mlir::acc::GangArgType::Dim);
2385 parser.
getContext(), mlir::acc::GangArgType::Static);
2388 if (needCommaBeforeOperands) {
2389 needCommaBeforeOperands =
false;
2396 int32_t crtOperandsSize = gangOperands.size();
2398 bool newValue =
false;
2399 bool needValue =
false;
2400 if (needCommaBetweenValues) {
2408 gangOperands, gangOperandsType,
2409 gangArgTypeAttributes, argNum,
2410 needCommaBetweenValues, newValue)))
2413 gangOperands, gangOperandsType,
2414 gangArgTypeAttributes, argDim,
2415 needCommaBetweenValues, newValue)))
2418 gangOperands, gangOperandsType,
2419 gangArgTypeAttributes, argStatic,
2420 needCommaBetweenValues, newValue)))
2423 if (!newValue && needValue) {
2425 "new value expected after comma");
2433 if (gangOperands.empty())
2436 "expect at least one of num, dim or static values");
2442 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2450 seg.push_back(gangOperands.size() - crtOperandsSize);
2458 gangArgTypeAttributes.end());
2463 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2472 std::optional<mlir::ArrayAttr> gangArgTypes,
2473 std::optional<mlir::ArrayAttr> deviceTypes,
2474 std::optional<mlir::DenseI32ArrayAttr> segments,
2475 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2477 if (operands.begin() == operands.end() &&
2492 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2494 llvm::interleaveComma(
2495 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2496 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2497 (*gangArgTypes)[opIdx]);
2498 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2499 p << LoopOp::getGangNumKeyword();
2500 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2501 p << LoopOp::getGangDimKeyword();
2502 else if (gangArgTypeAttr.getValue() ==
2503 mlir::acc::GangArgType::Static)
2504 p << LoopOp::getGangStaticKeyword();
2505 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2516 std::optional<mlir::ArrayAttr> segments,
2517 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2520 for (
auto attr : *segments) {
2521 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2522 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2530 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2533 for (
auto attr : deviceTypes) {
2534 auto deviceTypeAttr =
2535 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2536 if (!deviceTypeAttr)
2538 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2545 if (getUpperbound().size() != getStep().size())
2546 return emitError() <<
"number of upperbounds expected to be the same as "
2549 if (getUpperbound().size() != getLowerbound().size())
2550 return emitError() <<
"number of upperbounds expected to be the same as "
2551 "number of lowerbounds";
2553 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2554 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2555 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2556 <<
" as upperbound size";
2559 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2560 return emitOpError() <<
"collapse device_type attr must be define when"
2561 <<
" collapse attr is present";
2563 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2564 getCollapseAttr().getValue().size() !=
2565 getCollapseDeviceTypeAttr().getValue().size())
2566 return emitOpError() <<
"collapse attribute count must match collapse"
2567 <<
" device_type count";
2569 return emitOpError()
2570 <<
"duplicate device_type found in collapseDeviceType attribute";
2573 if (!getGangOperands().empty()) {
2574 if (!getGangOperandsArgType())
2575 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2576 <<
" when gang operands are present";
2578 if (getGangOperands().size() !=
2579 getGangOperandsArgTypeAttr().getValue().size())
2580 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2581 <<
" gangOperands count";
2584 return emitOpError() <<
"duplicate device_type found in gang attribute";
2587 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2588 getGangOperandsDeviceTypeAttr(),
"gang")))
2593 return emitOpError() <<
"duplicate device_type found in worker attribute";
2595 return emitOpError() <<
"duplicate device_type found in "
2596 "workerNumOperandsDeviceType attribute";
2598 getWorkerNumOperandsDeviceTypeAttr(),
2604 return emitOpError() <<
"duplicate device_type found in vector attribute";
2606 return emitOpError() <<
"duplicate device_type found in "
2607 "vectorOperandsDeviceType attribute";
2609 getVectorOperandsDeviceTypeAttr(),
2614 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2615 getTileOperandsDeviceTypeAttr(),
"tile")))
2619 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2623 return emitError() <<
"only one of auto, independent, seq can be present "
2629 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
2632 bool hasDefaultSeq =
2634 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2637 bool hasDefaultIndependent =
2638 getIndependentAttr()
2640 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2643 bool hasDefaultAuto =
2645 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2648 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2650 <<
"at least one of auto, independent, seq must be present";
2655 for (
auto attr : getSeqAttr()) {
2656 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2657 if (hasVector(deviceTypeAttr.getValue()) ||
2658 getVectorValue(deviceTypeAttr.getValue()) ||
2659 hasWorker(deviceTypeAttr.getValue()) ||
2660 getWorkerValue(deviceTypeAttr.getValue()) ||
2661 hasGang(deviceTypeAttr.getValue()) ||
2662 getGangValue(mlir::acc::GangArgType::Num,
2663 deviceTypeAttr.getValue()) ||
2664 getGangValue(mlir::acc::GangArgType::Dim,
2665 deviceTypeAttr.getValue()) ||
2666 getGangValue(mlir::acc::GangArgType::Static,
2667 deviceTypeAttr.getValue()))
2668 return emitError() <<
"gang, worker or vector cannot appear with seq";
2672 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2673 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2674 "privatizations",
false)))
2677 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2678 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2679 "reductions",
false)))
2682 if (getCombined().has_value() &&
2683 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2684 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2685 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2686 return emitError(
"unexpected combined constructs attribute");
2690 if (getRegion().empty())
2691 return emitError(
"expected non-empty body.");
2694 if (isContainerLike()) {
2697 uint64_t collapseCount = getCollapseValue().value_or(1);
2698 if (getCollapseAttr()) {
2699 for (
auto collapseEntry : getCollapseAttr()) {
2700 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2701 if (intAttr.getValue().getZExtValue() > collapseCount)
2702 collapseCount = intAttr.getValue().getZExtValue();
2710 bool foundSibling =
false;
2712 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2714 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2716 foundSibling =
true;
2721 expectedParent = op;
2724 if (collapseCount == 0)
2730 return emitError(
"found sibling loops inside container-like acc.loop");
2731 if (collapseCount != 0)
2732 return emitError(
"failed to find enough loop-like operations inside "
2733 "container-like acc.loop");
2739 unsigned LoopOp::getNumDataOperands() {
2740 return getReductionOperands().size() + getPrivateOperands().size();
2743 Value LoopOp::getDataOperand(
unsigned i) {
2744 unsigned numOptional =
2745 getLowerbound().size() + getUpperbound().size() + getStep().size();
2746 numOptional += getGangOperands().size();
2747 numOptional += getVectorOperands().size();
2748 numOptional += getWorkerNumOperands().size();
2749 numOptional += getTileOperands().size();
2750 numOptional += getCacheOperands().size();
2751 return getOperand(numOptional + i);
2756 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2760 bool LoopOp::hasIndependent() {
2764 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2770 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2778 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2780 getVectorOperands(), deviceType);
2785 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2793 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2795 getWorkerNumOperands(), deviceType);
2800 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2809 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2811 getTileOperandsSegments(), deviceType);
2814 std::optional<int64_t> LoopOp::getCollapseValue() {
2818 std::optional<int64_t>
2819 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2820 if (!getCollapseAttr())
2821 return std::nullopt;
2822 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2824 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2825 return intAttr.getValue().getZExtValue();
2827 return std::nullopt;
2830 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2834 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2835 mlir::acc::DeviceType deviceType) {
2836 if (getGangOperands().empty())
2838 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2839 int32_t nbOperandsBefore = 0;
2840 for (
unsigned i = 0; i < *pos; ++i)
2841 nbOperandsBefore += (*getGangOperandsSegments())[i];
2844 .drop_front(nbOperandsBefore)
2845 .take_front((*getGangOperandsSegments())[*pos]);
2847 int32_t argTypeIdx = nbOperandsBefore;
2848 for (
auto value : values) {
2849 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2850 (*getGangOperandsArgType())[argTypeIdx]);
2851 if (gangArgTypeAttr.getValue() == gangArgType)
2861 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2866 return {&getRegion()};
2910 if (!regionArgs.empty()) {
2911 p << acc::LoopOp::getControlKeyword() <<
"(";
2912 llvm::interleaveComma(regionArgs, p,
2914 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2915 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2916 <<
" : " << stepType <<
") ";
2923 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2924 effectiveDeviceTypes));
2927 void acc::LoopOp::addIndependent(
2929 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2930 context, getIndependentAttr(), effectiveDeviceTypes));
2935 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2936 effectiveDeviceTypes));
2939 void acc::LoopOp::setCollapseForDeviceTypes(
2941 llvm::APInt value) {
2945 assert((getCollapseAttr() ==
nullptr) ==
2946 (getCollapseDeviceTypeAttr() ==
nullptr));
2947 assert(value.getBitWidth() == 64);
2949 if (getCollapseAttr()) {
2950 for (
const auto &existing :
2951 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2952 newValues.push_back(std::get<0>(existing));
2953 newDeviceTypes.push_back(std::get<1>(existing));
2957 if (effectiveDeviceTypes.empty()) {
2960 newValues.push_back(
2962 newDeviceTypes.push_back(
2965 for (DeviceType DT : effectiveDeviceTypes) {
2966 newValues.push_back(
2973 setCollapseDeviceTypeAttr(
ArrayAttr::get(context, newDeviceTypes));
2976 void acc::LoopOp::setTileForDeviceTypes(
2980 if (getTileOperandsSegments())
2981 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2983 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2984 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2985 getTileOperandsMutable(), segments));
2987 setTileOperandsSegments(segments);
2990 void acc::LoopOp::addVectorOperand(
2993 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2994 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2995 newValue, getVectorOperandsMutable()));
2998 void acc::LoopOp::addEmptyVector(
3000 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3001 effectiveDeviceTypes));
3004 void acc::LoopOp::addWorkerNumOperand(
3007 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3008 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3009 newValue, getWorkerNumOperandsMutable()));
3012 void acc::LoopOp::addEmptyWorker(
3014 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3015 effectiveDeviceTypes));
3018 void acc::LoopOp::addEmptyGang(
3020 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3021 effectiveDeviceTypes));
3024 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3025 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3026 return attr.getValue() == dt;
3028 auto testFromArr = [=](ArrayAttr arr) ->
bool {
3029 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3032 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3034 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3036 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3042 bool acc::LoopOp::hasDefaultGangWorkerVector() {
3043 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3044 hasGang() || getGangValue(GangArgType::Num) ||
3045 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3049 acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3050 if (hasSeq(deviceType))
3051 return LoopParMode::loop_seq;
3052 if (hasAuto(deviceType))
3053 return LoopParMode::loop_auto;
3054 if (hasIndependent(deviceType))
3055 return LoopParMode::loop_independent;
3057 return LoopParMode::loop_seq;
3059 return LoopParMode::loop_auto;
3060 assert(hasIndependent() &&
3061 "loop must have default auto, seq, or independent");
3062 return LoopParMode::loop_independent;
3065 void acc::LoopOp::addGangOperands(
3070 getGangOperandsSegments())
3071 llvm::copy(*existingSegments, std::back_inserter(segments));
3073 unsigned beforeCount = segments.size();
3075 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3076 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3077 getGangOperandsMutable(), segments));
3079 setGangOperandsSegments(segments);
3086 unsigned numAdded = segments.size() - beforeCount;
3090 if (getGangOperandsArgTypeAttr())
3091 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3093 for (
auto i : llvm::index_range(0u, numAdded)) {
3094 llvm::transform(argTypes, std::back_inserter(gangTypes),
3095 [=](mlir::acc::GangArgType gangTy) {
3105 void acc::LoopOp::addPrivatization(
MLIRContext *context,
3106 mlir::acc::PrivateOp op,
3107 mlir::acc::PrivateRecipeOp recipe) {
3108 getPrivateOperandsMutable().append(op.getResult());
3112 if (getPrivatizationRecipesAttr())
3113 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3120 void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3121 mlir::acc::ReductionRecipeOp recipe) {
3122 getReductionOperandsMutable().append(op.getResult());
3126 if (getReductionRecipesAttr())
3127 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3142 if (getOperands().empty() && !getDefaultAttr())
3143 return emitError(
"at least one operand or the default attribute "
3144 "must appear on the data operation");
3146 for (
mlir::Value operand : getDataClauseOperands())
3147 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3148 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3149 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3150 operand.getDefiningOp()))
3151 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3154 if (
failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
3160 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3162 Value DataOp::getDataOperand(
unsigned i) {
3163 unsigned numOptional = getIfCond() ? 1 : 0;
3165 numOptional += getWaitOperands().size();
3166 return getOperand(numOptional + i);
3169 bool acc::DataOp::hasAsyncOnly() {
3173 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3181 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3188 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3197 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3199 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3200 getHasWaitDevnum(), deviceType);
3207 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3209 getWaitOperandsSegments(), getHasWaitDevnum(),
3213 void acc::DataOp::addAsyncOnly(
3215 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3216 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3219 void acc::DataOp::addAsyncOperand(
3222 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3223 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3224 getAsyncOperandsMutable()));
3227 void acc::DataOp::addWaitOnly(
MLIRContext *context,
3229 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3230 effectiveDeviceTypes));
3233 void acc::DataOp::addWaitOperands(
3238 if (getWaitOperandsSegments())
3239 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3241 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3242 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3243 getWaitOperandsMutable(), segments));
3244 setWaitOperandsSegments(segments);
3247 if (getHasWaitDevnumAttr())
3248 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3251 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3264 if (getDataClauseOperands().empty())
3265 return emitError(
"at least one operand must be present in dataOperands on "
3266 "the exit data operation");
3270 if (getAsyncOperand() && getAsync())
3271 return emitError(
"async attribute cannot appear with asyncOperand");
3275 if (!getWaitOperands().empty() && getWait())
3276 return emitError(
"wait attribute cannot appear with waitOperands");
3278 if (getWaitDevnum() && getWaitOperands().empty())
3279 return emitError(
"wait_devnum cannot appear without waitOperands");
3284 unsigned ExitDataOp::getNumDataOperands() {
3285 return getDataClauseOperands().size();
3288 Value ExitDataOp::getDataOperand(
unsigned i) {
3289 unsigned numOptional = getIfCond() ? 1 : 0;
3290 numOptional += getAsyncOperand() ? 1 : 0;
3291 numOptional += getWaitDevnum() ? 1 : 0;
3292 return getOperand(getWaitOperands().size() + numOptional + i);
3297 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3300 void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3302 assert(effectiveDeviceTypes.empty());
3303 assert(!getAsyncAttr());
3304 assert(!getAsyncOperand());
3309 void ExitDataOp::addAsyncOperand(
3312 assert(effectiveDeviceTypes.empty());
3313 assert(!getAsyncAttr());
3314 assert(!getAsyncOperand());
3316 getAsyncOperandMutable().append(newValue);
3319 void ExitDataOp::addWaitOnly(
MLIRContext *context,
3321 assert(effectiveDeviceTypes.empty());
3322 assert(!getWaitAttr());
3323 assert(getWaitOperands().empty());
3324 assert(!getWaitDevnum());
3329 void ExitDataOp::addWaitOperands(
3332 assert(effectiveDeviceTypes.empty());
3333 assert(!getWaitAttr());
3334 assert(getWaitOperands().empty());
3335 assert(!getWaitDevnum());
3340 getWaitDevnumMutable().append(newValues.front());
3341 newValues = newValues.drop_front();
3344 getWaitOperandsMutable().append(newValues);
3355 if (getDataClauseOperands().empty())
3356 return emitError(
"at least one operand must be present in dataOperands on "
3357 "the enter data operation");
3361 if (getAsyncOperand() && getAsync())
3362 return emitError(
"async attribute cannot appear with asyncOperand");
3366 if (!getWaitOperands().empty() && getWait())
3367 return emitError(
"wait attribute cannot appear with waitOperands");
3369 if (getWaitDevnum() && getWaitOperands().empty())
3370 return emitError(
"wait_devnum cannot appear without waitOperands");
3372 for (
mlir::Value operand : getDataClauseOperands())
3373 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3374 operand.getDefiningOp()))
3375 return emitError(
"expect data entry operation as defining op");
3380 unsigned EnterDataOp::getNumDataOperands() {
3381 return getDataClauseOperands().size();
3384 Value EnterDataOp::getDataOperand(
unsigned i) {
3385 unsigned numOptional = getIfCond() ? 1 : 0;
3386 numOptional += getAsyncOperand() ? 1 : 0;
3387 numOptional += getWaitDevnum() ? 1 : 0;
3388 return getOperand(getWaitOperands().size() + numOptional + i);
3393 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3396 void EnterDataOp::addAsyncOnly(
3398 assert(effectiveDeviceTypes.empty());
3399 assert(!getAsyncAttr());
3400 assert(!getAsyncOperand());
3405 void EnterDataOp::addAsyncOperand(
3408 assert(effectiveDeviceTypes.empty());
3409 assert(!getAsyncAttr());
3410 assert(!getAsyncOperand());
3412 getAsyncOperandMutable().append(newValue);
3415 void EnterDataOp::addWaitOnly(
MLIRContext *context,
3417 assert(effectiveDeviceTypes.empty());
3418 assert(!getWaitAttr());
3419 assert(getWaitOperands().empty());
3420 assert(!getWaitDevnum());
3425 void EnterDataOp::addWaitOperands(
3428 assert(effectiveDeviceTypes.empty());
3429 assert(!getWaitAttr());
3430 assert(getWaitOperands().empty());
3431 assert(!getWaitDevnum());
3436 getWaitDevnumMutable().append(newValues.front());
3437 newValues = newValues.drop_front();
3440 getWaitOperandsMutable().append(newValues);
3459 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3466 if (
Value writeVal = op.getWriteOpVal()) {
3476 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3482 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3483 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3485 return dyn_cast<AtomicReadOp>(getSecondOp());
3488 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3489 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3491 return dyn_cast<AtomicWriteOp>(getSecondOp());
3494 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3495 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3497 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3500 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3506 template <
typename Op>
3507 static LogicalResult
3509 bool requireAtLeastOneOperand =
true) {
3510 if (operands.empty() && requireAtLeastOneOperand)
3513 "at least one operand must appear on the declare operation");
3516 if (!mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3517 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3518 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3519 operand.getDefiningOp()))
3521 "expect valid declare data entry operation or acc.getdeviceptr "
3525 assert(var &&
"declare operands can only be data entry operations which "
3528 std::optional<mlir::acc::DataClause> dataClauseOptional{
3530 assert(dataClauseOptional.has_value() &&
3531 "declare operands can only be data entry operations which must have "
3533 (void)dataClauseOptional;
3567 acc::DeviceType dtype) {
3568 unsigned parallelism = 0;
3569 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3570 parallelism += op.hasWorker(dtype) ? 1 : 0;
3571 parallelism += op.hasVector(dtype) ? 1 : 0;
3572 parallelism += op.hasSeq(dtype) ? 1 : 0;
3577 unsigned baseParallelism =
3580 if (baseParallelism > 1)
3581 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3582 "be present at the same time";
3584 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3586 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
3591 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3592 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3593 "be present at the same time";
3600 mlir::ArrayAttr &bindIdName,
3601 mlir::ArrayAttr &bindStrName,
3602 mlir::ArrayAttr &deviceIdTypes,
3603 mlir::ArrayAttr &deviceStrTypes) {
3610 mlir::Attribute newAttr;
3611 bool isSymbolRefAttr;
3612 auto parseResult = parser.parseAttribute(newAttr);
3613 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3614 bindIdNameAttrs.push_back(symbolRefAttr);
3615 isSymbolRefAttr = true;
3616 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3617 bindStrNameAttrs.push_back(stringAttr);
3618 isSymbolRefAttr =
false;
3623 if (isSymbolRefAttr) {
3624 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3625 parser.getContext(), mlir::acc::DeviceType::None));
3627 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3628 parser.getContext(), mlir::acc::DeviceType::None));
3631 if (isSymbolRefAttr) {
3632 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3633 parser.parseRSquare())
3636 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3637 parser.parseRSquare())
3645 bindIdName =
ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3646 bindStrName =
ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3647 deviceIdTypes =
ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3648 deviceStrTypes =
ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
3654 std::optional<mlir::ArrayAttr> bindIdName,
3655 std::optional<mlir::ArrayAttr> bindStrName,
3656 std::optional<mlir::ArrayAttr> deviceIdTypes,
3657 std::optional<mlir::ArrayAttr> deviceStrTypes) {
3664 allBindNames.append(bindIdName->begin(), bindIdName->end());
3665 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3670 allBindNames.append(bindStrName->begin(), bindStrName->end());
3671 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3675 if (!allBindNames.empty())
3676 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3677 [&](
const auto &pair) {
3678 p << std::get<0>(pair);
3684 mlir::ArrayAttr &gang,
3685 mlir::ArrayAttr &gangDim,
3686 mlir::ArrayAttr &gangDimDeviceTypes) {
3689 gangDimDeviceTypeAttrs;
3690 bool needCommaBeforeOperands =
false;
3703 if (parser.parseAttribute(gangAttrs.emplace_back()))
3710 needCommaBeforeOperands =
true;
3717 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3718 parser.parseColon() ||
3719 parser.parseAttribute(gangDimAttrs.emplace_back()))
3721 if (succeeded(parser.parseOptionalLSquare())) {
3722 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3723 parser.parseRSquare())
3726 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3727 parser.getContext(), mlir::acc::DeviceType::None));
3733 if (
failed(parser.parseRParen()))
3738 gangDimDeviceTypes =
3745 std::optional<mlir::ArrayAttr> gang,
3746 std::optional<mlir::ArrayAttr> gangDim,
3747 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3750 gang->size() == 1) {
3751 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3764 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3765 [&](
const auto &pair) {
3766 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
3767 p << std::get<0>(pair);
3775 mlir::ArrayAttr &deviceTypes) {
3788 if (parser.parseAttribute(attributes.emplace_back()))
3802 std::optional<mlir::ArrayAttr> deviceTypes) {
3805 auto deviceTypeAttr =
3806 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3816 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3824 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3830 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3836 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3840 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3841 RoutineOp::getBindNameValue() {
3845 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3846 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3849 return std::nullopt;
3852 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
3853 auto attr = (*getBindIdName())[*pos];
3854 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3855 assert(symbolRefAttr &&
"expected SymbolRef");
3856 return symbolRefAttr;
3859 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
3860 auto attr = (*getBindStrName())[*pos];
3861 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3862 assert(stringAttr &&
"expected String");
3866 return std::nullopt;
3871 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3875 std::optional<int64_t> RoutineOp::getGangDimValue() {
3879 std::optional<int64_t>
3880 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3882 return std::nullopt;
3883 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
3884 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3885 return intAttr.getInt();
3887 return std::nullopt;
3898 return emitOpError(
"cannot be nested in a compute operation");
3902 void acc::InitOp::addDeviceType(
MLIRContext *context,
3903 mlir::acc::DeviceType deviceType) {
3905 if (getDeviceTypesAttr())
3906 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3920 return emitOpError(
"cannot be nested in a compute operation");
3924 void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
3925 mlir::acc::DeviceType deviceType) {
3927 if (getDeviceTypesAttr())
3928 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3942 return emitOpError(
"cannot be nested in a compute operation");
3943 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3944 return emitOpError(
"at least one default_async, device_num, or device_type "
3945 "operand must appear");
3955 if (getDataClauseOperands().empty())
3956 return emitError(
"at least one value must be present in dataOperands");
3959 getAsyncOperandsDeviceTypeAttr(),
3964 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3965 getWaitOperandsDeviceTypeAttr(),
"wait")))
3968 if (
failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
3971 for (
mlir::Value operand : getDataClauseOperands())
3972 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3973 operand.getDefiningOp()))
3974 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3980 unsigned UpdateOp::getNumDataOperands() {
3981 return getDataClauseOperands().size();
3984 Value UpdateOp::getDataOperand(
unsigned i) {
3986 numOptional += getIfCond() ? 1 : 0;
3987 return getOperand(getWaitOperands().size() + numOptional + i);
3992 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
3995 bool UpdateOp::hasAsyncOnly() {
3999 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4007 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4017 bool UpdateOp::hasWaitOnly() {
4021 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4030 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4032 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4033 getHasWaitDevnum(), deviceType);
4040 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4042 getWaitOperandsSegments(), getHasWaitDevnum(),
4048 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4049 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4052 void UpdateOp::addAsyncOperand(
4055 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4056 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4057 getAsyncOperandsMutable()));
4062 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4063 effectiveDeviceTypes));
4066 void UpdateOp::addWaitOperands(
4071 if (getWaitOperandsSegments())
4072 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4074 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4075 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4076 getWaitOperandsMutable(), segments));
4077 setWaitOperandsSegments(segments);
4080 if (getHasWaitDevnumAttr())
4081 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4084 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4096 if (getAsyncOperand() && getAsync())
4097 return emitError(
"async attribute cannot appear with asyncOperand");
4099 if (getWaitDevnum() && getWaitOperands().empty())
4100 return emitError(
"wait_devnum cannot appear without waitOperands");
4105 #define GET_OP_CLASSES
4106 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4108 #define GET_ATTRDEF_CLASSES
4109 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4111 #define GET_TYPEDEF_CLASSES
4112 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4123 .Case<ACC_DATA_ENTRY_OPS>(
4124 [&](
auto entry) {
return entry.getVarPtr(); })
4125 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4126 [&](
auto exit) {
return exit.getVarPtr(); })
4144 [&](
auto entry) {
return entry.getVarType(); })
4145 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4146 [&](
auto exit) {
return exit.getVarType(); })
4156 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4157 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4167 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4176 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4186 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4188 dataClause.getBounds().begin(), dataClause.getBounds().end());
4200 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4202 dataClause.getAsyncOperands().begin(),
4203 dataClause.getAsyncOperands().end());
4214 return dataClause.getAsyncOperandsDeviceTypeAttr();
4222 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4229 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4236 std::optional<mlir::acc::DataClause>
4241 .Case<ACC_DATA_ENTRY_OPS>(
4242 [&](
auto entry) {
return entry.getDataClause(); })
4250 [&](
auto entry) {
return entry.getImplicit(); })
4259 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4261 return dataOperands;
4269 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4271 return dataOperands;
4277 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
mlir::Operation * getEnclosingComputeOp(mlir::Region ®ion)
Used to obtain the enclosing compute construct operation that contains the provided region.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.