21 #include "llvm/ADT/SmallSet.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/LogicalResult.h"
29 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
30 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
31 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
32 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
33 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
37 static bool isScalarLikeType(
Type type) {
41 struct MemRefPointerLikeModel
42 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
45 return cast<MemRefType>(pointer).getElementType();
47 mlir::acc::VariableTypeCategory
50 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
51 return mappableTy.getTypeCategory(varPtr);
53 auto memrefTy = cast<MemRefType>(pointer);
54 if (!memrefTy.hasRank()) {
57 return mlir::acc::VariableTypeCategory::uncategorized;
60 if (memrefTy.getRank() == 0) {
61 if (isScalarLikeType(memrefTy.getElementType())) {
62 return mlir::acc::VariableTypeCategory::scalar;
66 return mlir::acc::VariableTypeCategory::uncategorized;
70 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
71 return mlir::acc::VariableTypeCategory::array;
75 struct LLVMPointerPointerLikeModel
76 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
85 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
86 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
89 if (existingDeviceTypes)
90 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
92 if (newDeviceTypes.empty())
93 deviceTypes.push_back(
96 for (DeviceType dt : newDeviceTypes)
108 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
109 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
114 if (existingDeviceTypes)
115 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
117 if (newDeviceTypes.empty()) {
118 argCollection.
append(arguments);
119 segments.push_back(arguments.size());
120 deviceTypes.push_back(
124 for (DeviceType dt : newDeviceTypes) {
125 argCollection.
append(arguments);
126 segments.push_back(arguments.size());
134 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
135 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
139 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
140 newDeviceTypes, arguments,
141 argCollection, segments);
149 void OpenACCDialect::initialize() {
152 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
155 #define GET_ATTRDEF_LIST
156 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
159 #define GET_TYPEDEF_LIST
160 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
166 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
167 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
176 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
180 mlir::acc::DeviceType deviceType) {
184 for (
auto attr : *arrayAttr) {
185 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
186 if (deviceTypeAttr.getValue() == deviceType)
194 std::optional<mlir::ArrayAttr> deviceTypes) {
199 llvm::interleaveComma(*deviceTypes, p,
205 mlir::acc::DeviceType deviceType) {
206 unsigned segmentIdx = 0;
207 for (
auto attr : segments) {
208 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
209 if (deviceTypeAttr.getValue() == deviceType)
210 return std::make_optional(segmentIdx);
220 mlir::acc::DeviceType deviceType) {
222 return range.take_front(0);
223 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
224 int32_t nbOperandsBefore = 0;
225 for (
unsigned i = 0; i < *pos; ++i)
226 nbOperandsBefore += (*segments)[i];
227 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
229 return range.take_front(0);
236 std::optional<mlir::ArrayAttr> hasWaitDevnum,
237 mlir::acc::DeviceType deviceType) {
240 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
241 if (hasWaitDevnum->getValue()[*pos])
252 std::optional<mlir::ArrayAttr> hasWaitDevnum,
253 mlir::acc::DeviceType deviceType) {
258 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
259 if (hasWaitDevnum && *hasWaitDevnum) {
260 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
261 if (boolAttr.getValue())
262 return range.drop_front(1);
268 template <
typename Op>
270 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
272 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
277 op.hasAsyncOnly(dtype))
279 "asyncOnly attribute cannot appear with asyncOperand");
284 op.hasWaitOnly(dtype))
285 return op.
emitError(
"wait attribute cannot appear with waitOperands");
290 template <
typename Op>
293 return op.
emitError(
"must have var operand");
296 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
297 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
298 return op.
emitError(
"var must be mappable or pointer-like");
301 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
302 op.getVarType() == op.getVar().getType())
303 return op.
emitError(
"varType must capture the element type of var");
308 template <
typename Op>
310 if (op.getVar().getType() != op.getAccVar().getType())
311 return op.
emitError(
"input and output types must match");
316 template <
typename Op>
318 if (op.getModifiers() != acc::DataClauseModifier::none)
319 return op.
emitError(
"no data clause modifiers are allowed");
323 template <
typename Op>
326 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
328 "invalid data clause modifiers: " +
329 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
351 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
382 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
394 mlir::TypeAttr &varTypeAttr) {
411 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
413 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
422 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
430 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
431 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
433 if (typeToCheckAgainst != varType) {
444 auto extent = getExtent();
445 auto upperbound = getUpperbound();
446 if (!extent && !upperbound)
447 return emitError(
"expected extent or upperbound.");
457 "data clause associated with private operation must match its intent");
470 return emitError(
"data clause associated with firstprivate operation must "
484 return emitError(
"data clause associated with reduction operation must "
498 return emitError(
"data clause associated with deviceptr operation must "
515 "data clause associated with present operation must match its intent");
530 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
535 "data clause associated with copyin operation must match its intent"
536 " or specify original clause this operation was decomposed from");
542 acc::DataClauseModifier::always |
543 acc::DataClauseModifier::capture)))
548 bool acc::CopyinOp::isCopyinReadonly() {
549 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
550 acc::bitEnumContainsAny(getModifiers(),
551 acc::DataClauseModifier::readonly);
564 "data clause associated with create operation must match its intent"
565 " or specify original clause this operation was decomposed from");
573 acc::DataClauseModifier::always |
574 acc::DataClauseModifier::capture)))
579 bool acc::CreateOp::isCreateZero() {
581 return getDataClause() == acc::DataClause::acc_create_zero ||
583 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
591 return emitError(
"data clause associated with no_create operation must "
608 "data clause associated with attach operation must match its intent");
623 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
624 return emitError(
"data clause associated with device_resident operation "
625 "must match its intent");
642 "data clause associated with link operation must match its intent");
662 "data clause associated with copyout operation must match its intent"
663 " or specify original clause this operation was decomposed from");
665 return emitError(
"must have both host and device pointers");
671 acc::DataClauseModifier::always |
672 acc::DataClauseModifier::capture)))
677 bool acc::CopyoutOp::isCopyoutZero() {
678 return getDataClause() == acc::DataClause::acc_copyout_zero ||
679 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
694 getDataClause() != acc::DataClause::acc_declare_device_resident &&
697 "data clause associated with delete operation must match its intent"
698 " or specify original clause this operation was decomposed from");
700 return emitError(
"must have device pointer");
704 acc::DataClauseModifier::readonly |
705 acc::DataClauseModifier::always |
706 acc::DataClauseModifier::capture)))
719 "data clause associated with detach operation must match its intent"
720 " or specify original clause this operation was decomposed from");
722 return emitError(
"must have device pointer");
736 "data clause associated with host operation must match its intent"
737 " or specify original clause this operation was decomposed from");
739 return emitError(
"must have both host and device pointers");
756 "data clause associated with device operation must match its intent"
757 " or specify original clause this operation was decomposed from");
774 "data clause associated with use_device operation must match its intent"
775 " or specify original clause this operation was decomposed from");
793 "data clause associated with cache operation must match its intent"
794 " or specify original clause this operation was decomposed from");
804 bool acc::CacheOp::isCacheReadonly() {
805 return getDataClause() == acc::DataClause::acc_cache_readonly ||
806 acc::bitEnumContainsAny(getModifiers(),
807 acc::DataClauseModifier::readonly);
810 template <
typename StructureOp>
812 unsigned nRegions = 1) {
815 for (
unsigned i = 0; i < nRegions; ++i)
816 regions.push_back(state.addRegion());
818 for (
Region *region : regions)
826 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
833 template <
typename OpTy>
837 LogicalResult matchAndRewrite(OpTy op,
840 Value ifCond = op.getIfCond();
844 IntegerAttr constAttr;
847 if (constAttr.getInt())
848 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
860 assert(region.
hasOneBlock() &&
"expected single-block region");
872 template <
typename OpTy>
873 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
876 LogicalResult matchAndRewrite(OpTy op,
879 Value ifCond = op.getIfCond();
883 IntegerAttr constAttr;
886 if (constAttr.getInt())
887 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
902 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
904 if (optional && region.
empty())
908 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
912 return op->
emitOpError() <<
"expects " << regionName
915 << regionType <<
" type";
918 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
919 if (yieldOp.getOperands().size() != 1 ||
920 yieldOp.getOperands().getTypes()[0] != type)
921 return op->
emitOpError() <<
"expects " << regionName
923 "yield a value of the "
924 << regionType <<
" type";
930 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
932 "privatization",
"init",
getType(),
936 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
946 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
948 "privatization",
"init",
getType(),
952 if (getCopyRegion().empty())
953 return emitOpError() <<
"expects non-empty copy region";
958 return emitOpError() <<
"expects copy region with two arguments of the "
959 "privatization type";
961 if (getDestroyRegion().empty())
965 "privatization",
"destroy",
976 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
982 if (getCombinerRegion().empty())
983 return emitOpError() <<
"expects non-empty combiner region";
985 Block &reductionBlock = getCombinerRegion().
front();
989 return emitOpError() <<
"expects combiner region with the first two "
990 <<
"arguments of the reduction type";
992 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
993 if (yieldOp.getOperands().size() != 1 ||
994 yieldOp.getOperands().getTypes()[0] !=
getType())
995 return emitOpError() <<
"expects combiner region to yield a value "
996 "of the reduction type";
1012 if (parser.parseAttribute(attributes.emplace_back()) ||
1013 parser.parseArrow() ||
1014 parser.parseOperand(operands.emplace_back()) ||
1015 parser.parseColonType(types.emplace_back()))
1029 std::optional<mlir::ArrayAttr> attributes) {
1030 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
1031 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
1032 << std::get<1>(it).getType();
1041 template <
typename Op>
1045 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1046 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1047 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1048 operand.getDefiningOp()))
1050 "expect data entry/exit operation or acc.getdeviceptr "
1055 template <
typename Op>
1056 static LogicalResult
1059 llvm::StringRef symbolName,
bool checkOperandType =
true) {
1060 if (!operands.empty()) {
1061 if (!attributes || attributes->size() != operands.size())
1063 <<
"expected as many " << symbolName <<
" symbol reference as "
1064 << operandName <<
" operands";
1068 <<
"unexpected " << symbolName <<
" symbol reference";
1073 for (
auto args : llvm::zip(operands, *attributes)) {
1076 if (!set.insert(operand).second)
1078 << operandName <<
" operand appears more than once";
1081 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1082 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1085 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1086 << operandName <<
" declaration";
1088 if (checkOperandType && decl.getType() && decl.getType() != varType)
1089 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1090 <<
") to be the same type as " << operandName
1091 <<
" declaration (" << decl.getType() <<
")";
1097 unsigned ParallelOp::getNumDataOperands() {
1098 return getReductionOperands().size() + getPrivateOperands().size() +
1099 getFirstprivateOperands().size() + getDataClauseOperands().size();
1102 Value ParallelOp::getDataOperand(
unsigned i) {
1104 numOptional += getNumGangs().size();
1105 numOptional += getNumWorkers().size();
1106 numOptional += getVectorLength().size();
1107 numOptional += getIfCond() ? 1 : 0;
1108 numOptional += getSelfCond() ? 1 : 0;
1109 return getOperand(getWaitOperands().size() + numOptional + i);
1112 template <
typename Op>
1114 ArrayAttr deviceTypes,
1115 llvm::StringRef keyword) {
1116 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1117 return op.
emitOpError() << keyword <<
" operands count must match "
1118 << keyword <<
" device_type count";
1122 template <
typename Op>
1125 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1126 std::size_t numOperandsInSegments = 0;
1127 std::size_t nbOfSegments = 0;
1130 for (
auto segCount : segments.
asArrayRef()) {
1131 if (maxInSegment != 0 && segCount > maxInSegment)
1132 return op.
emitOpError() << keyword <<
" expects a maximum of "
1133 << maxInSegment <<
" values per segment";
1134 numOperandsInSegments += segCount;
1139 if ((numOperandsInSegments != operands.size()) ||
1140 (!deviceTypes && !operands.empty()))
1142 << keyword <<
" operand count does not match count in segments";
1143 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1145 << keyword <<
" segment count does not match device_type count";
1150 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1151 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1152 "privatizations",
false)))
1154 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1155 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1156 "firstprivate",
"firstprivatizations",
false)))
1158 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1159 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1160 "reductions",
false)))
1164 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1165 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1169 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1170 getWaitOperandsDeviceTypeAttr(),
"wait")))
1174 getNumWorkersDeviceTypeAttr(),
1179 getVectorLengthDeviceTypeAttr(),
1184 getAsyncOperandsDeviceTypeAttr(),
1188 if (
failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1191 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1197 mlir::acc::DeviceType deviceType) {
1200 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1205 bool acc::ParallelOp::hasAsyncOnly() {
1209 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1217 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1222 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1227 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1232 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1237 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1239 getVectorLength(), deviceType);
1247 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1249 getNumGangsSegments(), deviceType);
1252 bool acc::ParallelOp::hasWaitOnly() {
1256 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1265 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1267 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1268 getHasWaitDevnum(), deviceType);
1275 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1277 getWaitOperandsSegments(), getHasWaitDevnum(),
1293 odsBuilder, odsState, asyncOperands,
nullptr,
1294 nullptr, waitOperands,
nullptr,
1296 nullptr, numGangs,
nullptr,
1297 nullptr, numWorkers,
1298 nullptr, vectorLength,
1299 nullptr, ifCond, selfCond,
1300 nullptr, reductionOperands,
nullptr,
1301 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1302 nullptr, dataClauseOperands,
1306 void acc::ParallelOp::addNumWorkersOperand(
1309 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1310 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1311 getNumWorkersMutable()));
1313 void acc::ParallelOp::addVectorLengthOperand(
1316 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1317 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1318 getVectorLengthMutable()));
1321 void acc::ParallelOp::addAsyncOnly(
1323 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1324 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1327 void acc::ParallelOp::addAsyncOperand(
1330 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1331 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1332 getAsyncOperandsMutable()));
1335 void acc::ParallelOp::addNumGangsOperands(
1339 if (getNumGangsSegments())
1340 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1342 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1343 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1344 getNumGangsMutable(), segments));
1346 setNumGangsSegments(segments);
1348 void acc::ParallelOp::addWaitOnly(
1350 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1351 effectiveDeviceTypes));
1353 void acc::ParallelOp::addWaitOperands(
1358 if (getWaitOperandsSegments())
1359 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1361 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1362 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1363 getWaitOperandsMutable(), segments));
1364 setWaitOperandsSegments(segments);
1367 if (getHasWaitDevnumAttr())
1368 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1371 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1376 void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1377 mlir::acc::PrivateOp op,
1378 mlir::acc::PrivateRecipeOp recipe) {
1379 getPrivateOperandsMutable().append(op.getResult());
1383 if (getPrivatizationRecipesAttr())
1384 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1391 void acc::ParallelOp::addFirstPrivatization(
1392 MLIRContext *context, mlir::acc::FirstprivateOp op,
1393 mlir::acc::FirstprivateRecipeOp recipe) {
1394 getFirstprivateOperandsMutable().append(op.getResult());
1398 if (getFirstprivatizationRecipesAttr())
1399 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1406 void acc::ParallelOp::addReduction(
MLIRContext *context,
1407 mlir::acc::ReductionOp op,
1408 mlir::acc::ReductionRecipeOp recipe) {
1409 getReductionOperandsMutable().append(op.getResult());
1413 if (getReductionRecipesAttr())
1414 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1433 int32_t crtOperandsSize = operands.size();
1436 if (parser.parseOperand(operands.emplace_back()) ||
1437 parser.parseColonType(types.emplace_back()))
1442 seg.push_back(operands.size() - crtOperandsSize);
1466 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1468 p <<
" [" << attr <<
"]";
1473 std::optional<mlir::ArrayAttr> deviceTypes,
1474 std::optional<mlir::DenseI32ArrayAttr> segments) {
1476 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1478 llvm::interleaveComma(
1479 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1480 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1500 int32_t crtOperandsSize = operands.size();
1504 if (parser.parseOperand(operands.emplace_back()) ||
1505 parser.parseColonType(types.emplace_back()))
1511 seg.push_back(operands.size() - crtOperandsSize);
1537 std::optional<mlir::DenseI32ArrayAttr> segments) {
1539 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1541 llvm::interleaveComma(
1542 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1543 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1556 mlir::ArrayAttr &keywordOnly) {
1560 bool needCommaBeforeOperands =
false;
1573 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1580 needCommaBeforeOperands =
true;
1590 int32_t crtOperandsSize = operands.size();
1602 if (parser.parseOperand(operands.emplace_back()) ||
1603 parser.parseColonType(types.emplace_back()))
1609 seg.push_back(operands.size() - crtOperandsSize);
1638 if (attrs->size() != 1)
1640 if (
auto deviceTypeAttr =
1641 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
1648 std::optional<mlir::ArrayAttr> deviceTypes,
1649 std::optional<mlir::DenseI32ArrayAttr> segments,
1650 std::optional<mlir::ArrayAttr> hasDevNum,
1651 std::optional<mlir::ArrayAttr> keywordOnly) {
1664 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1666 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
1667 if (boolAttr && boolAttr.getValue())
1669 llvm::interleaveComma(
1670 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1671 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1688 if (parser.parseOperand(operands.emplace_back()) ||
1689 parser.parseColonType(types.emplace_back()))
1691 if (succeeded(parser.parseOptionalLSquare())) {
1692 if (parser.parseAttribute(attributes.emplace_back()) ||
1693 parser.parseRSquare())
1696 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1697 parser.getContext(), mlir::acc::DeviceType::None));
1711 std::optional<mlir::ArrayAttr> deviceTypes) {
1714 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
1715 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
1724 mlir::ArrayAttr &keywordOnlyDeviceType) {
1727 bool needCommaBeforeOperands =
false;
1733 keywordOnlyDeviceType =
1742 if (parser.parseAttribute(
1743 keywordOnlyDeviceTypeAttributes.emplace_back()))
1750 needCommaBeforeOperands =
true;
1758 if (parser.parseOperand(operands.emplace_back()) ||
1759 parser.parseColonType(types.emplace_back()))
1761 if (succeeded(parser.parseOptionalLSquare())) {
1762 if (parser.parseAttribute(attributes.emplace_back()) ||
1763 parser.parseRSquare())
1766 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1767 parser.getContext(), mlir::acc::DeviceType::None));
1773 if (
failed(parser.parseRParen()))
1785 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
1787 if (operands.begin() == operands.end() &&
1803 std::optional<OpAsmParser::UnresolvedOperand> &operand,
1804 mlir::Type &operandType, mlir::UnitAttr &attr) {
1827 std::optional<mlir::Value> operand,
1829 mlir::UnitAttr attr) {
1851 if (parser.parseOperand(operands.emplace_back()))
1859 if (parser.parseType(types.emplace_back()))
1874 mlir::UnitAttr attr) {
1879 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
1881 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
1887 mlir::acc::CombinedConstructsTypeAttr &attr) {
1890 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
1893 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
1896 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
1899 "expected compute construct name");
1907 mlir::acc::CombinedConstructsTypeAttr attr) {
1909 switch (attr.getValue()) {
1910 case mlir::acc::CombinedConstructsType::KernelsLoop:
1913 case mlir::acc::CombinedConstructsType::ParallelLoop:
1916 case mlir::acc::CombinedConstructsType::SerialLoop:
1927 unsigned SerialOp::getNumDataOperands() {
1928 return getReductionOperands().size() + getPrivateOperands().size() +
1929 getFirstprivateOperands().size() + getDataClauseOperands().size();
1932 Value SerialOp::getDataOperand(
unsigned i) {
1934 numOptional += getIfCond() ? 1 : 0;
1935 numOptional += getSelfCond() ? 1 : 0;
1936 return getOperand(getWaitOperands().size() + numOptional + i);
1939 bool acc::SerialOp::hasAsyncOnly() {
1943 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1951 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1956 bool acc::SerialOp::hasWaitOnly() {
1960 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1969 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1971 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1972 getHasWaitDevnum(), deviceType);
1979 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1981 getWaitOperandsSegments(), getHasWaitDevnum(),
1986 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1987 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1988 "privatizations",
false)))
1990 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1991 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1992 "firstprivate",
"firstprivatizations",
false)))
1994 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1995 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1996 "reductions",
false)))
2000 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2001 getWaitOperandsDeviceTypeAttr(),
"wait")))
2005 getAsyncOperandsDeviceTypeAttr(),
2009 if (
failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
2012 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
2015 void acc::SerialOp::addAsyncOnly(
2017 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2018 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2021 void acc::SerialOp::addAsyncOperand(
2024 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2025 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2026 getAsyncOperandsMutable()));
2029 void acc::SerialOp::addWaitOnly(
2031 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2032 effectiveDeviceTypes));
2034 void acc::SerialOp::addWaitOperands(
2039 if (getWaitOperandsSegments())
2040 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2042 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2043 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2044 getWaitOperandsMutable(), segments));
2045 setWaitOperandsSegments(segments);
2048 if (getHasWaitDevnumAttr())
2049 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2052 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2057 void acc::SerialOp::addPrivatization(
MLIRContext *context,
2058 mlir::acc::PrivateOp op,
2059 mlir::acc::PrivateRecipeOp recipe) {
2060 getPrivateOperandsMutable().append(op.getResult());
2064 if (getPrivatizationRecipesAttr())
2065 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2072 void acc::SerialOp::addFirstPrivatization(
2073 MLIRContext *context, mlir::acc::FirstprivateOp op,
2074 mlir::acc::FirstprivateRecipeOp recipe) {
2075 getFirstprivateOperandsMutable().append(op.getResult());
2079 if (getFirstprivatizationRecipesAttr())
2080 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2087 void acc::SerialOp::addReduction(
MLIRContext *context,
2088 mlir::acc::ReductionOp op,
2089 mlir::acc::ReductionRecipeOp recipe) {
2090 getReductionOperandsMutable().append(op.getResult());
2094 if (getReductionRecipesAttr())
2095 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2106 unsigned KernelsOp::getNumDataOperands() {
2107 return getDataClauseOperands().size();
2110 Value KernelsOp::getDataOperand(
unsigned i) {
2112 numOptional += getWaitOperands().size();
2113 numOptional += getNumGangs().size();
2114 numOptional += getNumWorkers().size();
2115 numOptional += getVectorLength().size();
2116 numOptional += getIfCond() ? 1 : 0;
2117 numOptional += getSelfCond() ? 1 : 0;
2118 return getOperand(numOptional + i);
2121 bool acc::KernelsOp::hasAsyncOnly() {
2125 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2133 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2138 mlir::Value acc::KernelsOp::getNumWorkersValue() {
2143 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2148 mlir::Value acc::KernelsOp::getVectorLengthValue() {
2153 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2155 getVectorLength(), deviceType);
2163 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2165 getNumGangsSegments(), deviceType);
2168 bool acc::KernelsOp::hasWaitOnly() {
2172 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2181 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2183 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2184 getHasWaitDevnum(), deviceType);
2191 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2193 getWaitOperandsSegments(), getHasWaitDevnum(),
2199 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2200 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2204 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2205 getWaitOperandsDeviceTypeAttr(),
"wait")))
2209 getNumWorkersDeviceTypeAttr(),
2214 getVectorLengthDeviceTypeAttr(),
2219 getAsyncOperandsDeviceTypeAttr(),
2223 if (
failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
2226 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
2229 void acc::KernelsOp::addNumWorkersOperand(
2232 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2233 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2234 getNumWorkersMutable()));
2237 void acc::KernelsOp::addVectorLengthOperand(
2240 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2241 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2242 getVectorLengthMutable()));
2244 void acc::KernelsOp::addAsyncOnly(
2246 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2247 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2250 void acc::KernelsOp::addAsyncOperand(
2253 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2254 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2255 getAsyncOperandsMutable()));
2258 void acc::KernelsOp::addNumGangsOperands(
2262 if (getNumGangsSegmentsAttr())
2263 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2265 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2266 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2267 getNumGangsMutable(), segments));
2269 setNumGangsSegments(segments);
2272 void acc::KernelsOp::addWaitOnly(
2274 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2275 effectiveDeviceTypes));
2277 void acc::KernelsOp::addWaitOperands(
2282 if (getWaitOperandsSegments())
2283 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2285 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2286 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2287 getWaitOperandsMutable(), segments));
2288 setWaitOperandsSegments(segments);
2291 if (getHasWaitDevnumAttr())
2292 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2295 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2305 if (getDataClauseOperands().empty())
2306 return emitError(
"at least one operand must appear on the host_data "
2309 for (
mlir::Value operand : getDataClauseOperands())
2310 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2311 return emitError(
"expect data entry operation as defining op");
2317 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2329 bool &needCommaBetweenValues,
bool &newValue) {
2336 attributes.push_back(gangArgType);
2337 needCommaBetweenValues =
true;
2348 mlir::ArrayAttr &gangOnlyDeviceType) {
2353 bool needCommaBetweenValues =
false;
2354 bool needCommaBeforeOperands =
false;
2360 gangOnlyDeviceType =
2369 if (parser.parseAttribute(
2370 gangOnlyDeviceTypeAttributes.emplace_back()))
2377 needCommaBeforeOperands =
true;
2381 mlir::acc::GangArgType::Num);
2383 mlir::acc::GangArgType::Dim);
2385 parser.
getContext(), mlir::acc::GangArgType::Static);
2388 if (needCommaBeforeOperands) {
2389 needCommaBeforeOperands =
false;
2396 int32_t crtOperandsSize = gangOperands.size();
2398 bool newValue =
false;
2399 bool needValue =
false;
2400 if (needCommaBetweenValues) {
2408 gangOperands, gangOperandsType,
2409 gangArgTypeAttributes, argNum,
2410 needCommaBetweenValues, newValue)))
2413 gangOperands, gangOperandsType,
2414 gangArgTypeAttributes, argDim,
2415 needCommaBetweenValues, newValue)))
2418 gangOperands, gangOperandsType,
2419 gangArgTypeAttributes, argStatic,
2420 needCommaBetweenValues, newValue)))
2423 if (!newValue && needValue) {
2425 "new value expected after comma");
2433 if (gangOperands.empty())
2436 "expect at least one of num, dim or static values");
2442 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2450 seg.push_back(gangOperands.size() - crtOperandsSize);
2458 gangArgTypeAttributes.end());
2463 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2472 std::optional<mlir::ArrayAttr> gangArgTypes,
2473 std::optional<mlir::ArrayAttr> deviceTypes,
2474 std::optional<mlir::DenseI32ArrayAttr> segments,
2475 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2477 if (operands.begin() == operands.end() &&
2492 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2494 llvm::interleaveComma(
2495 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2496 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2497 (*gangArgTypes)[opIdx]);
2498 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2499 p << LoopOp::getGangNumKeyword();
2500 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2501 p << LoopOp::getGangDimKeyword();
2502 else if (gangArgTypeAttr.getValue() ==
2503 mlir::acc::GangArgType::Static)
2504 p << LoopOp::getGangStaticKeyword();
2505 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2516 std::optional<mlir::ArrayAttr> segments,
2517 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2520 for (
auto attr : *segments) {
2521 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2522 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2530 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2533 for (
auto attr : deviceTypes) {
2534 auto deviceTypeAttr =
2535 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2536 if (!deviceTypeAttr)
2538 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2545 if (getUpperbound().size() != getStep().size())
2546 return emitError() <<
"number of upperbounds expected to be the same as "
2549 if (getUpperbound().size() != getLowerbound().size())
2550 return emitError() <<
"number of upperbounds expected to be the same as "
2551 "number of lowerbounds";
2553 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2554 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2555 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2556 <<
" as upperbound size";
2559 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2560 return emitOpError() <<
"collapse device_type attr must be define when"
2561 <<
" collapse attr is present";
2563 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2564 getCollapseAttr().getValue().size() !=
2565 getCollapseDeviceTypeAttr().getValue().size())
2566 return emitOpError() <<
"collapse attribute count must match collapse"
2567 <<
" device_type count";
2569 return emitOpError()
2570 <<
"duplicate device_type found in collapseDeviceType attribute";
2573 if (!getGangOperands().empty()) {
2574 if (!getGangOperandsArgType())
2575 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2576 <<
" when gang operands are present";
2578 if (getGangOperands().size() !=
2579 getGangOperandsArgTypeAttr().getValue().size())
2580 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2581 <<
" gangOperands count";
2584 return emitOpError() <<
"duplicate device_type found in gang attribute";
2587 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2588 getGangOperandsDeviceTypeAttr(),
"gang")))
2593 return emitOpError() <<
"duplicate device_type found in worker attribute";
2595 return emitOpError() <<
"duplicate device_type found in "
2596 "workerNumOperandsDeviceType attribute";
2598 getWorkerNumOperandsDeviceTypeAttr(),
2604 return emitOpError() <<
"duplicate device_type found in vector attribute";
2606 return emitOpError() <<
"duplicate device_type found in "
2607 "vectorOperandsDeviceType attribute";
2609 getVectorOperandsDeviceTypeAttr(),
2614 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2615 getTileOperandsDeviceTypeAttr(),
"tile")))
2619 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2623 return emitError() <<
"only one of auto, independent, seq can be present "
2629 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
2632 bool hasDefaultSeq =
2634 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2637 bool hasDefaultIndependent =
2638 getIndependentAttr()
2640 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2643 bool hasDefaultAuto =
2645 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2648 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
2650 <<
"at least one of auto, independent, seq must be present";
2655 for (
auto attr : getSeqAttr()) {
2656 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2657 if (hasVector(deviceTypeAttr.getValue()) ||
2658 getVectorValue(deviceTypeAttr.getValue()) ||
2659 hasWorker(deviceTypeAttr.getValue()) ||
2660 getWorkerValue(deviceTypeAttr.getValue()) ||
2661 hasGang(deviceTypeAttr.getValue()) ||
2662 getGangValue(mlir::acc::GangArgType::Num,
2663 deviceTypeAttr.getValue()) ||
2664 getGangValue(mlir::acc::GangArgType::Dim,
2665 deviceTypeAttr.getValue()) ||
2666 getGangValue(mlir::acc::GangArgType::Static,
2667 deviceTypeAttr.getValue()))
2668 return emitError() <<
"gang, worker or vector cannot appear with seq";
2672 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2673 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2674 "privatizations",
false)))
2677 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2678 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2679 "reductions",
false)))
2682 if (getCombined().has_value() &&
2683 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
2684 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
2685 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
2686 return emitError(
"unexpected combined constructs attribute");
2690 if (getRegion().empty())
2691 return emitError(
"expected non-empty body.");
2694 if (isContainerLike()) {
2697 uint64_t collapseCount = getCollapseValue().value_or(1);
2698 if (getCollapseAttr()) {
2699 for (
auto collapseEntry : getCollapseAttr()) {
2700 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
2701 if (intAttr.getValue().getZExtValue() > collapseCount)
2702 collapseCount = intAttr.getValue().getZExtValue();
2710 bool foundSibling =
false;
2712 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
2714 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
2716 foundSibling =
true;
2721 expectedParent = op;
2724 if (collapseCount == 0)
2730 return emitError(
"found sibling loops inside container-like acc.loop");
2731 if (collapseCount != 0)
2732 return emitError(
"failed to find enough loop-like operations inside "
2733 "container-like acc.loop");
2739 unsigned LoopOp::getNumDataOperands() {
2740 return getReductionOperands().size() + getPrivateOperands().size();
2743 Value LoopOp::getDataOperand(
unsigned i) {
2744 unsigned numOptional =
2745 getLowerbound().size() + getUpperbound().size() + getStep().size();
2746 numOptional += getGangOperands().size();
2747 numOptional += getVectorOperands().size();
2748 numOptional += getWorkerNumOperands().size();
2749 numOptional += getTileOperands().size();
2750 numOptional += getCacheOperands().size();
2751 return getOperand(numOptional + i);
2756 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
2760 bool LoopOp::hasIndependent() {
2764 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
2770 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
2778 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
2780 getVectorOperands(), deviceType);
2785 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
2793 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
2795 getWorkerNumOperands(), deviceType);
2800 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
2809 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
2811 getTileOperandsSegments(), deviceType);
2814 std::optional<int64_t> LoopOp::getCollapseValue() {
2818 std::optional<int64_t>
2819 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
2820 if (!getCollapseAttr())
2821 return std::nullopt;
2822 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
2824 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
2825 return intAttr.getValue().getZExtValue();
2827 return std::nullopt;
2830 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
2834 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
2835 mlir::acc::DeviceType deviceType) {
2836 if (getGangOperands().empty())
2838 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
2839 int32_t nbOperandsBefore = 0;
2840 for (
unsigned i = 0; i < *pos; ++i)
2841 nbOperandsBefore += (*getGangOperandsSegments())[i];
2844 .drop_front(nbOperandsBefore)
2845 .take_front((*getGangOperandsSegments())[*pos]);
2847 int32_t argTypeIdx = nbOperandsBefore;
2848 for (
auto value : values) {
2849 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2850 (*getGangOperandsArgType())[argTypeIdx]);
2851 if (gangArgTypeAttr.getValue() == gangArgType)
2861 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
2866 return {&getRegion()};
2910 if (!regionArgs.empty()) {
2911 p << acc::LoopOp::getControlKeyword() <<
"(";
2912 llvm::interleaveComma(regionArgs, p,
2914 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
2915 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
2916 <<
" : " << stepType <<
") ";
2923 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
2924 effectiveDeviceTypes));
2927 void acc::LoopOp::addIndependent(
2929 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
2930 context, getIndependentAttr(), effectiveDeviceTypes));
2935 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
2936 effectiveDeviceTypes));
2939 void acc::LoopOp::setCollapseForDeviceTypes(
2941 llvm::APInt value) {
2945 assert((getCollapseAttr() ==
nullptr) ==
2946 (getCollapseDeviceTypeAttr() ==
nullptr));
2947 assert(value.getBitWidth() == 64);
2949 if (getCollapseAttr()) {
2950 for (
const auto &existing :
2951 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
2952 newValues.push_back(std::get<0>(existing));
2953 newDeviceTypes.push_back(std::get<1>(existing));
2957 if (effectiveDeviceTypes.empty()) {
2960 newValues.push_back(
2962 newDeviceTypes.push_back(
2965 for (DeviceType dt : effectiveDeviceTypes) {
2966 newValues.push_back(
2973 setCollapseDeviceTypeAttr(
ArrayAttr::get(context, newDeviceTypes));
2976 void acc::LoopOp::setTileForDeviceTypes(
2980 if (getTileOperandsSegments())
2981 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
2983 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2984 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
2985 getTileOperandsMutable(), segments));
2987 setTileOperandsSegments(segments);
2990 void acc::LoopOp::addVectorOperand(
2993 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2994 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
2995 newValue, getVectorOperandsMutable()));
2998 void acc::LoopOp::addEmptyVector(
3000 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3001 effectiveDeviceTypes));
3004 void acc::LoopOp::addWorkerNumOperand(
3007 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3008 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3009 newValue, getWorkerNumOperandsMutable()));
3012 void acc::LoopOp::addEmptyWorker(
3014 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3015 effectiveDeviceTypes));
3018 void acc::LoopOp::addEmptyGang(
3020 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3021 effectiveDeviceTypes));
3024 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3025 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3026 return attr.getValue() == dt;
3028 auto testFromArr = [=](ArrayAttr arr) ->
bool {
3029 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3032 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3034 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3036 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3042 bool acc::LoopOp::hasDefaultGangWorkerVector() {
3043 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3044 hasGang() || getGangValue(GangArgType::Num) ||
3045 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3049 acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3050 if (hasSeq(deviceType))
3051 return LoopParMode::loop_seq;
3052 if (hasAuto(deviceType))
3053 return LoopParMode::loop_auto;
3054 if (hasIndependent(deviceType))
3055 return LoopParMode::loop_independent;
3057 return LoopParMode::loop_seq;
3059 return LoopParMode::loop_auto;
3060 assert(hasIndependent() &&
3061 "loop must have default auto, seq, or independent");
3062 return LoopParMode::loop_independent;
3065 void acc::LoopOp::addGangOperands(
3070 getGangOperandsSegments())
3071 llvm::copy(*existingSegments, std::back_inserter(segments));
3073 unsigned beforeCount = segments.size();
3075 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3076 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3077 getGangOperandsMutable(), segments));
3079 setGangOperandsSegments(segments);
3086 unsigned numAdded = segments.size() - beforeCount;
3090 if (getGangOperandsArgTypeAttr())
3091 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3093 for (
auto i : llvm::index_range(0u, numAdded)) {
3094 llvm::transform(argTypes, std::back_inserter(gangTypes),
3095 [=](mlir::acc::GangArgType gangTy) {
3105 void acc::LoopOp::addPrivatization(
MLIRContext *context,
3106 mlir::acc::PrivateOp op,
3107 mlir::acc::PrivateRecipeOp recipe) {
3108 getPrivateOperandsMutable().append(op.getResult());
3112 if (getPrivatizationRecipesAttr())
3113 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3120 void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3121 mlir::acc::ReductionRecipeOp recipe) {
3122 getReductionOperandsMutable().append(op.getResult());
3126 if (getReductionRecipesAttr())
3127 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3142 if (getOperands().empty() && !getDefaultAttr())
3143 return emitError(
"at least one operand or the default attribute "
3144 "must appear on the data operation");
3146 for (
mlir::Value operand : getDataClauseOperands())
3147 if (isa<BlockArgument>(operand) ||
3148 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3149 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3150 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3151 operand.getDefiningOp()))
3152 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3155 if (
failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
3161 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3163 Value DataOp::getDataOperand(
unsigned i) {
3164 unsigned numOptional = getIfCond() ? 1 : 0;
3166 numOptional += getWaitOperands().size();
3167 return getOperand(numOptional + i);
3170 bool acc::DataOp::hasAsyncOnly() {
3174 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3182 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3189 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3198 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3200 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3201 getHasWaitDevnum(), deviceType);
3208 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3210 getWaitOperandsSegments(), getHasWaitDevnum(),
3214 void acc::DataOp::addAsyncOnly(
3216 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3217 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3220 void acc::DataOp::addAsyncOperand(
3223 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3224 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3225 getAsyncOperandsMutable()));
3228 void acc::DataOp::addWaitOnly(
MLIRContext *context,
3230 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3231 effectiveDeviceTypes));
3234 void acc::DataOp::addWaitOperands(
3239 if (getWaitOperandsSegments())
3240 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3242 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3243 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3244 getWaitOperandsMutable(), segments));
3245 setWaitOperandsSegments(segments);
3248 if (getHasWaitDevnumAttr())
3249 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3252 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3265 if (getDataClauseOperands().empty())
3266 return emitError(
"at least one operand must be present in dataOperands on "
3267 "the exit data operation");
3271 if (getAsyncOperand() && getAsync())
3272 return emitError(
"async attribute cannot appear with asyncOperand");
3276 if (!getWaitOperands().empty() && getWait())
3277 return emitError(
"wait attribute cannot appear with waitOperands");
3279 if (getWaitDevnum() && getWaitOperands().empty())
3280 return emitError(
"wait_devnum cannot appear without waitOperands");
3285 unsigned ExitDataOp::getNumDataOperands() {
3286 return getDataClauseOperands().size();
3289 Value ExitDataOp::getDataOperand(
unsigned i) {
3290 unsigned numOptional = getIfCond() ? 1 : 0;
3291 numOptional += getAsyncOperand() ? 1 : 0;
3292 numOptional += getWaitDevnum() ? 1 : 0;
3293 return getOperand(getWaitOperands().size() + numOptional + i);
3298 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3301 void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3303 assert(effectiveDeviceTypes.empty());
3304 assert(!getAsyncAttr());
3305 assert(!getAsyncOperand());
3310 void ExitDataOp::addAsyncOperand(
3313 assert(effectiveDeviceTypes.empty());
3314 assert(!getAsyncAttr());
3315 assert(!getAsyncOperand());
3317 getAsyncOperandMutable().append(newValue);
3320 void ExitDataOp::addWaitOnly(
MLIRContext *context,
3322 assert(effectiveDeviceTypes.empty());
3323 assert(!getWaitAttr());
3324 assert(getWaitOperands().empty());
3325 assert(!getWaitDevnum());
3330 void ExitDataOp::addWaitOperands(
3333 assert(effectiveDeviceTypes.empty());
3334 assert(!getWaitAttr());
3335 assert(getWaitOperands().empty());
3336 assert(!getWaitDevnum());
3341 getWaitDevnumMutable().append(newValues.front());
3342 newValues = newValues.drop_front();
3345 getWaitOperandsMutable().append(newValues);
3356 if (getDataClauseOperands().empty())
3357 return emitError(
"at least one operand must be present in dataOperands on "
3358 "the enter data operation");
3362 if (getAsyncOperand() && getAsync())
3363 return emitError(
"async attribute cannot appear with asyncOperand");
3367 if (!getWaitOperands().empty() && getWait())
3368 return emitError(
"wait attribute cannot appear with waitOperands");
3370 if (getWaitDevnum() && getWaitOperands().empty())
3371 return emitError(
"wait_devnum cannot appear without waitOperands");
3373 for (
mlir::Value operand : getDataClauseOperands())
3374 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3375 operand.getDefiningOp()))
3376 return emitError(
"expect data entry operation as defining op");
3381 unsigned EnterDataOp::getNumDataOperands() {
3382 return getDataClauseOperands().size();
3385 Value EnterDataOp::getDataOperand(
unsigned i) {
3386 unsigned numOptional = getIfCond() ? 1 : 0;
3387 numOptional += getAsyncOperand() ? 1 : 0;
3388 numOptional += getWaitDevnum() ? 1 : 0;
3389 return getOperand(getWaitOperands().size() + numOptional + i);
3394 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3397 void EnterDataOp::addAsyncOnly(
3399 assert(effectiveDeviceTypes.empty());
3400 assert(!getAsyncAttr());
3401 assert(!getAsyncOperand());
3406 void EnterDataOp::addAsyncOperand(
3409 assert(effectiveDeviceTypes.empty());
3410 assert(!getAsyncAttr());
3411 assert(!getAsyncOperand());
3413 getAsyncOperandMutable().append(newValue);
3416 void EnterDataOp::addWaitOnly(
MLIRContext *context,
3418 assert(effectiveDeviceTypes.empty());
3419 assert(!getWaitAttr());
3420 assert(getWaitOperands().empty());
3421 assert(!getWaitDevnum());
3426 void EnterDataOp::addWaitOperands(
3429 assert(effectiveDeviceTypes.empty());
3430 assert(!getWaitAttr());
3431 assert(getWaitOperands().empty());
3432 assert(!getWaitDevnum());
3437 getWaitDevnumMutable().append(newValues.front());
3438 newValues = newValues.drop_front();
3441 getWaitOperandsMutable().append(newValues);
3460 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3467 if (
Value writeVal = op.getWriteOpVal()) {
3477 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3483 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3484 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3486 return dyn_cast<AtomicReadOp>(getSecondOp());
3489 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3490 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3492 return dyn_cast<AtomicWriteOp>(getSecondOp());
3495 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3496 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3498 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3501 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3507 template <
typename Op>
3508 static LogicalResult
3510 bool requireAtLeastOneOperand =
true) {
3511 if (operands.empty() && requireAtLeastOneOperand)
3514 "at least one operand must appear on the declare operation");
3517 if (isa<BlockArgument>(operand) ||
3518 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3519 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3520 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3521 operand.getDefiningOp()))
3523 "expect valid declare data entry operation or acc.getdeviceptr "
3527 assert(var &&
"declare operands can only be data entry operations which "
3530 std::optional<mlir::acc::DataClause> dataClauseOptional{
3532 assert(dataClauseOptional.has_value() &&
3533 "declare operands can only be data entry operations which must have "
3535 (void)dataClauseOptional;
3569 acc::DeviceType dtype) {
3570 unsigned parallelism = 0;
3571 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3572 parallelism += op.hasWorker(dtype) ? 1 : 0;
3573 parallelism += op.hasVector(dtype) ? 1 : 0;
3574 parallelism += op.hasSeq(dtype) ? 1 : 0;
3579 unsigned baseParallelism =
3582 if (baseParallelism > 1)
3583 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3584 "be present at the same time";
3586 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3588 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
3593 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3594 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3595 "be present at the same time";
3602 mlir::ArrayAttr &bindIdName,
3603 mlir::ArrayAttr &bindStrName,
3604 mlir::ArrayAttr &deviceIdTypes,
3605 mlir::ArrayAttr &deviceStrTypes) {
3612 mlir::Attribute newAttr;
3613 bool isSymbolRefAttr;
3614 auto parseResult = parser.parseAttribute(newAttr);
3615 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3616 bindIdNameAttrs.push_back(symbolRefAttr);
3617 isSymbolRefAttr = true;
3618 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3619 bindStrNameAttrs.push_back(stringAttr);
3620 isSymbolRefAttr =
false;
3625 if (isSymbolRefAttr) {
3626 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3627 parser.getContext(), mlir::acc::DeviceType::None));
3629 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3630 parser.getContext(), mlir::acc::DeviceType::None));
3633 if (isSymbolRefAttr) {
3634 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
3635 parser.parseRSquare())
3638 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
3639 parser.parseRSquare())
3647 bindIdName =
ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
3648 bindStrName =
ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
3649 deviceIdTypes =
ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
3650 deviceStrTypes =
ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
3656 std::optional<mlir::ArrayAttr> bindIdName,
3657 std::optional<mlir::ArrayAttr> bindStrName,
3658 std::optional<mlir::ArrayAttr> deviceIdTypes,
3659 std::optional<mlir::ArrayAttr> deviceStrTypes) {
3666 allBindNames.append(bindIdName->begin(), bindIdName->end());
3667 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
3672 allBindNames.append(bindStrName->begin(), bindStrName->end());
3673 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
3677 if (!allBindNames.empty())
3678 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
3679 [&](
const auto &pair) {
3680 p << std::get<0>(pair);
3686 mlir::ArrayAttr &gang,
3687 mlir::ArrayAttr &gangDim,
3688 mlir::ArrayAttr &gangDimDeviceTypes) {
3691 gangDimDeviceTypeAttrs;
3692 bool needCommaBeforeOperands =
false;
3705 if (parser.parseAttribute(gangAttrs.emplace_back()))
3712 needCommaBeforeOperands =
true;
3719 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
3720 parser.parseColon() ||
3721 parser.parseAttribute(gangDimAttrs.emplace_back()))
3723 if (succeeded(parser.parseOptionalLSquare())) {
3724 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
3725 parser.parseRSquare())
3728 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
3729 parser.getContext(), mlir::acc::DeviceType::None));
3735 if (
failed(parser.parseRParen()))
3740 gangDimDeviceTypes =
3747 std::optional<mlir::ArrayAttr> gang,
3748 std::optional<mlir::ArrayAttr> gangDim,
3749 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
3752 gang->size() == 1) {
3753 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
3766 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
3767 [&](
const auto &pair) {
3768 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
3769 p << std::get<0>(pair);
3777 mlir::ArrayAttr &deviceTypes) {
3790 if (parser.parseAttribute(attributes.emplace_back()))
3804 std::optional<mlir::ArrayAttr> deviceTypes) {
3807 auto deviceTypeAttr =
3808 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
3818 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3826 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
3832 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
3838 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
3842 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3843 RoutineOp::getBindNameValue() {
3847 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
3848 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
3851 return std::nullopt;
3854 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
3855 auto attr = (*getBindIdName())[*pos];
3856 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
3857 assert(symbolRefAttr &&
"expected SymbolRef");
3858 return symbolRefAttr;
3861 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
3862 auto attr = (*getBindStrName())[*pos];
3863 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
3864 assert(stringAttr &&
"expected String");
3868 return std::nullopt;
3873 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
3877 std::optional<int64_t> RoutineOp::getGangDimValue() {
3881 std::optional<int64_t>
3882 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
3884 return std::nullopt;
3885 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
3886 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
3887 return intAttr.getInt();
3889 return std::nullopt;
3900 return emitOpError(
"cannot be nested in a compute operation");
3904 void acc::InitOp::addDeviceType(
MLIRContext *context,
3905 mlir::acc::DeviceType deviceType) {
3907 if (getDeviceTypesAttr())
3908 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3922 return emitOpError(
"cannot be nested in a compute operation");
3926 void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
3927 mlir::acc::DeviceType deviceType) {
3929 if (getDeviceTypesAttr())
3930 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
3944 return emitOpError(
"cannot be nested in a compute operation");
3945 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
3946 return emitOpError(
"at least one default_async, device_num, or device_type "
3947 "operand must appear");
3957 if (getDataClauseOperands().empty())
3958 return emitError(
"at least one value must be present in dataOperands");
3961 getAsyncOperandsDeviceTypeAttr(),
3966 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3967 getWaitOperandsDeviceTypeAttr(),
"wait")))
3970 if (
failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
3973 for (
mlir::Value operand : getDataClauseOperands())
3974 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
3975 operand.getDefiningOp()))
3976 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3982 unsigned UpdateOp::getNumDataOperands() {
3983 return getDataClauseOperands().size();
3986 Value UpdateOp::getDataOperand(
unsigned i) {
3988 numOptional += getIfCond() ? 1 : 0;
3989 return getOperand(getWaitOperands().size() + numOptional + i);
3994 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
3997 bool UpdateOp::hasAsyncOnly() {
4001 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4009 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4019 bool UpdateOp::hasWaitOnly() {
4023 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4032 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4034 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4035 getHasWaitDevnum(), deviceType);
4042 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4044 getWaitOperandsSegments(), getHasWaitDevnum(),
4050 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4051 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4054 void UpdateOp::addAsyncOperand(
4057 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4058 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4059 getAsyncOperandsMutable()));
4064 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4065 effectiveDeviceTypes));
4068 void UpdateOp::addWaitOperands(
4073 if (getWaitOperandsSegments())
4074 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4076 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4077 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4078 getWaitOperandsMutable(), segments));
4079 setWaitOperandsSegments(segments);
4082 if (getHasWaitDevnumAttr())
4083 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4086 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4098 if (getAsyncOperand() && getAsync())
4099 return emitError(
"async attribute cannot appear with asyncOperand");
4101 if (getWaitDevnum() && getWaitOperands().empty())
4102 return emitError(
"wait_devnum cannot appear without waitOperands");
4107 #define GET_OP_CLASSES
4108 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4110 #define GET_ATTRDEF_CLASSES
4111 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4113 #define GET_TYPEDEF_CLASSES
4114 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4125 .Case<ACC_DATA_ENTRY_OPS>(
4126 [&](
auto entry) {
return entry.getVarPtr(); })
4127 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4128 [&](
auto exit) {
return exit.getVarPtr(); })
4146 [&](
auto entry) {
return entry.getVarType(); })
4147 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4148 [&](
auto exit) {
return exit.getVarType(); })
4158 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4159 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4169 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4178 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4188 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4190 dataClause.getBounds().begin(), dataClause.getBounds().end());
4202 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4204 dataClause.getAsyncOperands().begin(),
4205 dataClause.getAsyncOperands().end());
4216 return dataClause.getAsyncOperandsDeviceTypeAttr();
4224 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4231 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4238 std::optional<mlir::acc::DataClause>
4243 .Case<ACC_DATA_ENTRY_OPS>(
4244 [&](
auto entry) {
return entry.getDataClause(); })
4252 [&](
auto entry) {
return entry.getImplicit(); })
4261 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4263 return dataOperands;
4271 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4273 return dataOperands;
4279 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
static WalkResult advance()
static WalkResult interrupt()
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
mlir::Operation * getEnclosingComputeOp(mlir::Region ®ion)
Used to obtain the enclosing compute construct operation that contains the provided region.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed, this helps unify some of the attribute constructio...
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs,...
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.