23 #include "llvm/ADT/SmallSet.h"
24 #include "llvm/ADT/TypeSwitch.h"
25 #include "llvm/Support/LogicalResult.h"
31 #include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32 #include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33 #include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34 #include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35 #include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
39 static bool isScalarLikeType(
Type type) {
/// External model for the PointerLikeType interface (see the ExternalModel
/// base below). Its visible methods cast the pointer type to MemRefType.
/// Only part of the model is visible in this excerpt; the comments below
/// describe the statements that are shown.
43 struct MemRefPointerLikeModel
44 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
// The element type reported for the pointer is the memref element type.
47 return cast<MemRefType>(pointer).getElementType();
// Classification of a variable into an acc::VariableTypeCategory.
50 mlir::acc::VariableTypeCategory
// When the variable type is a MappableType, its own category is returned.
53 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
54 return mappableTy.getTypeCategory(varPtr);
56 auto memrefTy = cast<MemRefType>(pointer);
// Memrefs without a rank: the next visible statement returns
// `uncategorized`.
57 if (!memrefTy.hasRank()) {
60 return mlir::acc::VariableTypeCategory::uncategorized;
// Rank-0 memrefs are `scalar` when the element type is scalar-like; the
// following visible return yields `uncategorized`.
63 if (memrefTy.getRank() == 0) {
64 if (isScalarLikeType(memrefTy.getElementType())) {
65 return mlir::acc::VariableTypeCategory::scalar;
69 return mlir::acc::VariableTypeCategory::uncategorized;
// Any remaining memref has positive rank and is reported as `array`.
73 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
74 return mlir::acc::VariableTypeCategory::array;
// Allocation method (start of the signature is outside this excerpt;
// presumably the interface's allocate hook). `needsFree` is passed by
// reference.
78 StringRef varName,
Type varType,
Value originalVar,
79 bool &needsFree)
const {
80 auto memrefTy = cast<MemRefType>(pointer);
// Statically shaped memrefs are created with memref::AllocaOp.
84 if (memrefTy.hasStaticShape()) {
86 return memref::AllocaOp::create(builder, loc, memrefTy).getResult();
// Otherwise, when an original variable of the same memref type is available
// (the rest of this condition is not visible), a memref::DimOp result is
// collected for each dynamic dimension and passed to memref::AllocOp.
91 if (originalVar && originalVar.
getType() == memrefTy &&
94 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
95 if (memrefTy.isDynamicDim(i)) {
99 memref::DimOp::create(builder, loc, originalVar, indexValue);
100 dynamicSizes.push_back(dimSize);
106 return memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes)
// Deallocation method (start of the signature is outside this excerpt).
116 Type varType)
const {
// Inspect the allocation result when one is given, else the memref value.
119 Value valueToInspect = allocRes ? allocRes : memrefValue;
122 Value currentValue = valueToInspect;
// Walk defining ops: an AllocOp/AllocaOp is recorded as `originalAlloc`,
// while CastOp and ReinterpretCastOp are stepped through via their source.
127 while (currentValue) {
130 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
131 originalAlloc = definingOp;
136 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
137 currentValue = castOp.getSource();
142 if (
auto reinterpretCastOp =
143 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
144 currentValue = reinterpretCastOp.getSource();
// AllocaOp-backed storage: the body of this branch is not visible here.
156 if (isa<memref::AllocaOp>(originalAlloc)) {
// AllocOp-backed storage gets a memref::DeallocOp on `memrefValue`.
160 if (isa<memref::AllocOp>(originalAlloc)) {
162 memref::DeallocOp::create(builder, loc, memrefValue);
// Copy method body (signature outside this excerpt): both values are viewed
// as memref-typed values; either may be null after the dyn_cast.
175 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
176 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
// A memref::CopyOp is created only when both are memrefs with identical
// element type and shape.
182 if (destMemref && srcMemref &&
183 destMemref.getType().getElementType() ==
184 srcMemref.getType().getElementType() &&
185 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
186 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
/// External model that provides the PointerLikeType interface for
/// LLVM::LLVMPointerType (the body is outside this excerpt).
194 struct LLVMPointerPointerLikeModel
195 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
196 LLVM::LLVMPointerType> {
204 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
205 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
208 if (existingDeviceTypes)
209 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
211 if (newDeviceTypes.empty())
212 deviceTypes.push_back(
215 for (DeviceType dt : newDeviceTypes)
227 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
228 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
233 if (existingDeviceTypes)
234 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
236 if (newDeviceTypes.empty()) {
237 argCollection.
append(arguments);
238 segments.push_back(arguments.size());
239 deviceTypes.push_back(
243 for (DeviceType dt : newDeviceTypes) {
244 argCollection.
append(arguments);
245 segments.push_back(arguments.size());
253 mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
254 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
258 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
259 newDeviceTypes, arguments,
260 argCollection, segments);
/// Dialect initialization. The visible lines pull in the tablegen-generated
/// operation, attribute and type lists and attach the pointer-like external
/// models to MemRefType and LLVM::LLVMPointerType. The registration calls
/// that wrap the includes are not visible in this excerpt.
268 void OpenACCDialect::initialize() {
// Generated operation list.
271 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
// Generated attribute definitions list.
274 #define GET_ATTRDEF_LIST
275 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
// Generated type definitions list.
278 #define GET_TYPEDEF_LIST
279 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
// Attach the external interface models to types owned by other dialects.
285 MemRefType::attachInterface<MemRefPointerLikeModel>(*
getContext());
286 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
295 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
299 mlir::acc::DeviceType deviceType) {
303 for (
auto attr : *arrayAttr) {
304 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
305 if (deviceTypeAttr.getValue() == deviceType)
313 std::optional<mlir::ArrayAttr> deviceTypes) {
318 llvm::interleaveComma(*deviceTypes, p,
324 mlir::acc::DeviceType deviceType) {
325 unsigned segmentIdx = 0;
326 for (
auto attr : segments) {
327 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
328 if (deviceTypeAttr.getValue() == deviceType)
329 return std::make_optional(segmentIdx);
339 mlir::acc::DeviceType deviceType) {
341 return range.take_front(0);
342 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
343 int32_t nbOperandsBefore = 0;
344 for (
unsigned i = 0; i < *pos; ++i)
345 nbOperandsBefore += (*segments)[i];
346 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
348 return range.take_front(0);
355 std::optional<mlir::ArrayAttr> hasWaitDevnum,
356 mlir::acc::DeviceType deviceType) {
359 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
360 if (hasWaitDevnum->getValue()[*pos])
371 std::optional<mlir::ArrayAttr> hasWaitDevnum,
372 mlir::acc::DeviceType deviceType) {
377 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
378 if (hasWaitDevnum && *hasWaitDevnum) {
379 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
380 if (boolAttr.getValue())
381 return range.drop_front(1);
387 template <
typename Op>
389 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
391 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
396 op.hasAsyncOnly(dtype))
398 "asyncOnly attribute cannot appear with asyncOperand");
403 op.hasWaitOnly(dtype))
404 return op.
emitError(
"wait attribute cannot appear with waitOperands");
409 template <
typename Op>
412 return op.
emitError(
"must have var operand");
415 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
416 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
417 return op.
emitError(
"var must be mappable or pointer-like");
420 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
421 op.getVarType() == op.getVar().getType())
422 return op.
emitError(
"varType must capture the element type of var");
427 template <
typename Op>
429 if (op.getVar().getType() != op.getAccVar().getType())
430 return op.
emitError(
"input and output types must match");
435 template <
typename Op>
437 if (op.getModifiers() != acc::DataClauseModifier::none)
438 return op.
emitError(
"no data clause modifiers are allowed");
442 template <
typename Op>
445 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
447 "invalid data clause modifiers: " +
448 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
470 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
501 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
513 mlir::TypeAttr &varTypeAttr) {
530 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
532 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
541 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
549 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
550 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
552 if (typeToCheckAgainst != varType) {
563 auto extent = getExtent();
564 auto upperbound = getUpperbound();
565 if (!extent && !upperbound)
566 return emitError(
"expected extent or upperbound.");
576 "data clause associated with private operation must match its intent");
589 return emitError(
"data clause associated with firstprivate operation must "
603 return emitError(
"data clause associated with reduction operation must "
617 return emitError(
"data clause associated with deviceptr operation must "
634 "data clause associated with present operation must match its intent");
649 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
654 "data clause associated with copyin operation must match its intent"
655 " or specify original clause this operation was decomposed from");
661 acc::DataClauseModifier::always |
662 acc::DataClauseModifier::capture)))
/// Returns true when this copyin is read-only: either its data clause is
/// `acc_copyin_readonly` or its modifiers contain `readonly`.
667 bool acc::CopyinOp::isCopyinReadonly() {
668 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
669 acc::bitEnumContainsAny(getModifiers(),
670 acc::DataClauseModifier::readonly);
683 "data clause associated with create operation must match its intent"
684 " or specify original clause this operation was decomposed from");
692 acc::DataClauseModifier::always |
693 acc::DataClauseModifier::capture)))
/// Returns true when this create has zero semantics: either its data clause
/// is `acc_create_zero` or its modifiers contain `zero`.
698 bool acc::CreateOp::isCreateZero() {
700 return getDataClause() == acc::DataClause::acc_create_zero ||
702 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
710 return emitError(
"data clause associated with no_create operation must "
727 "data clause associated with attach operation must match its intent");
742 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
743 return emitError(
"data clause associated with device_resident operation "
744 "must match its intent");
761 "data clause associated with link operation must match its intent");
781 "data clause associated with copyout operation must match its intent"
782 " or specify original clause this operation was decomposed from");
784 return emitError(
"must have both host and device pointers");
790 acc::DataClauseModifier::always |
791 acc::DataClauseModifier::capture)))
/// Returns true when this copyout has zero semantics: either its data clause
/// is `acc_copyout_zero` or its modifiers contain `zero`.
796 bool acc::CopyoutOp::isCopyoutZero() {
797 return getDataClause() == acc::DataClause::acc_copyout_zero ||
798 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
813 getDataClause() != acc::DataClause::acc_declare_device_resident &&
816 "data clause associated with delete operation must match its intent"
817 " or specify original clause this operation was decomposed from");
819 return emitError(
"must have device pointer");
823 acc::DataClauseModifier::readonly |
824 acc::DataClauseModifier::always |
825 acc::DataClauseModifier::capture)))
838 "data clause associated with detach operation must match its intent"
839 " or specify original clause this operation was decomposed from");
841 return emitError(
"must have device pointer");
855 "data clause associated with host operation must match its intent"
856 " or specify original clause this operation was decomposed from");
858 return emitError(
"must have both host and device pointers");
875 "data clause associated with device operation must match its intent"
876 " or specify original clause this operation was decomposed from");
893 "data clause associated with use_device operation must match its intent"
894 " or specify original clause this operation was decomposed from");
912 "data clause associated with cache operation must match its intent"
913 " or specify original clause this operation was decomposed from");
/// Returns true when this cache is read-only: either its data clause is
/// `acc_cache_readonly` or its modifiers contain `readonly`.
923 bool acc::CacheOp::isCacheReadonly() {
924 return getDataClause() == acc::DataClause::acc_cache_readonly ||
925 acc::bitEnumContainsAny(getModifiers(),
926 acc::DataClauseModifier::readonly);
929 template <
typename StructureOp>
931 unsigned nRegions = 1) {
934 for (
unsigned i = 0; i < nRegions; ++i)
935 regions.push_back(state.addRegion());
937 for (
Region *region : regions)
945 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
952 template <
typename OpTy>
956 LogicalResult matchAndRewrite(OpTy op,
959 Value ifCond = op.getIfCond();
963 IntegerAttr constAttr;
966 if (constAttr.getInt())
967 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
979 assert(region.
hasOneBlock() &&
"expected single-block region");
991 template <
typename OpTy>
992 struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
995 LogicalResult matchAndRewrite(OpTy op,
998 Value ifCond = op.getIfCond();
1002 IntegerAttr constAttr;
1005 if (constAttr.getInt())
1006 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
/// Builds a detached block for a recipe's init region and returns it.
/// The block receives one argument per entry collected in `argTypes`
/// (each bound contributes its type), produces a privatized value through
/// either the MappableType or the PointerLikeType interface, and ends with an
/// acc.yield of that value. Parts of the signature and the failure returns
/// are outside this excerpt.
1021 static std::unique_ptr<Block> createInitRegion(
OpBuilder &builder,
Location loc,
1022 Type varType, StringRef varName,
// Every bound adds a block argument of its own type at `loc`.
1028 for (
Value bound : bounds) {
1029 argTypes.push_back(bound.getType());
1030 argLocs.push_back(loc);
1033 auto initBlock = std::make_unique<Block>();
1034 initBlock->addArguments(argTypes, argLocs);
1037 Value privatizedValue;
// Block argument 0 stands for the variable being privatized.
1040 Value blockArgVar = initBlock->getArgument(0);
// Mappable types generate their own private initialization.
1043 if (isa<MappableType>(varType)) {
1044 auto mappableTy = cast<MappableType>(varType);
1045 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1046 privatizedValue = mappableTy.generatePrivateInit(
1047 builder, loc, typedVar, varName, bounds, {}, needsFree);
1048 if (!privatizedValue)
// Otherwise the type must be pointer-like and allocates through genAllocate.
1051 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1052 auto pointerLikeTy = cast<PointerLikeType>(varType);
1054 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1055 blockArgVar, needsFree);
1056 if (!privatizedValue)
// The privatized value is the result yielded by the init block.
1061 acc::YieldOp::create(builder, loc, privatizedValue);
/// Builds a detached block for a recipe's copy region. The block arguments
/// come from `copyArgTypes` (each bound contributes its type); for
/// pointer-like types the copy is generated with genCopy using block
/// arguments 0 and 1, and the block ends with an acc.terminator. Parts of the
/// signature and the early returns are outside this excerpt.
1069 static std::unique_ptr<Block> createCopyRegion(
OpBuilder &builder,
Location loc,
// Every bound adds a block argument of its own type at `loc`.
1076 for (
Value bound : bounds) {
1077 copyArgTypes.push_back(bound.getType());
1078 copyArgLocs.push_back(loc);
1081 auto copyBlock = std::make_unique<Block>();
1082 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1085 bool isMappable = isa<MappableType>(varType);
1086 bool isPointerLike = isa<PointerLikeType>(varType);
// Types that are mappable but not pointer-like take a separate path whose
// body is not visible here.
1089 if (isMappable && !isPointerLike)
// Pointer-like types: argument 0 is the original, argument 1 the private copy.
1093 if (isPointerLike) {
1094 auto pointerLikeTy = cast<PointerLikeType>(varType);
1095 Value originalArg = copyBlock->getArgument(0);
1096 Value privatizedArg = copyBlock->getArgument(1);
1099 if (!pointerLikeTy.genCopy(
1106 acc::TerminatorOp::create(builder, loc);
/// Builds a detached block for a recipe's destroy region and returns it.
/// The block arguments come from `destroyArgTypes` (each bound contributes
/// its type); the private copy (block argument 1) is released through the
/// PointerLikeType genFree hook and the block ends with an acc.terminator.
/// Parts of the signature and the early returns are outside this excerpt.
1113 static std::unique_ptr<Block> createDestroyRegion(
OpBuilder &builder,
// Every bound adds a block argument of its own type at `loc`.
1121 for (
Value bound : bounds) {
1122 destroyArgTypes.push_back(bound.getType());
1123 destroyArgLocs.push_back(loc);
1126 auto destroyBlock = std::make_unique<Block>();
1127 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1130 bool isMappable = isa<MappableType>(varType);
1131 bool isPointerLike = isa<PointerLikeType>(varType);
// Types that are mappable but not pointer-like take a separate path whose
// body is not visible here.
1134 if (isMappable && !isPointerLike)
// From here on the type is required to be pointer-like.
1137 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1138 auto pointerLikeTy = cast<PointerLikeType>(varType);
1139 auto privatizedArg =
1140 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
// genFree also receives the value produced by the init region (`allocRes`).
1142 if (!pointerLikeTy.genFree(builder, loc, privatizedArg, allocRes, varType))
1145 acc::TerminatorOp::create(builder, loc);
1147 return destroyBlock;
// Region-verification helper (the first line of its signature and several
// conditions are outside this excerpt). It reports op errors that mention the
// region by `regionName` and the expected type by `regionType`.
1157 Operation *op,
Region &region, StringRef regionType, StringRef regionName,
// An optional region may be empty.
1159 if (optional && region.
empty())
// Diagnostic emitted for an empty region.
1163 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
// Diagnostic emitted when the region's argument type is not the expected one
// (the condition and part of the message are not visible here).
1167 return op->
emitOpError() <<
"expects " << regionName
1170 << regionType <<
" type";
// Every acc.yield in the region must carry exactly one operand of `type`.
1173 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1174 if (yieldOp.getOperands().size() != 1 ||
1175 yieldOp.getOperands().getTypes()[0] != type)
1176 return op->
emitOpError() <<
"expects " << regionName
1178 "yield a value of the "
1179 << regionType <<
" type";
/// Region verifier for acc.private.recipe. The visible arguments show two
/// checks labelled with the "privatization" type: one for the "init" region
/// and one for the destroy region, both against the recipe's type. The
/// callee names are outside this excerpt.
1185 LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1187 "privatization",
"init",
getType(),
1191 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
/// Factory returning an optional PrivateRecipeOp (the line carrying the
/// function name is outside this excerpt). It yields std::nullopt when the
/// variable type is neither mappable nor pointer-like, and on the other
/// visible failure paths; otherwise it creates the recipe and moves the
/// generated init and destroy blocks into its regions.
1197 std::optional<PrivateRecipeOp>
1199 StringRef recipeName,
Type varType,
1202 bool isMappable = isa<MappableType>(varType);
1203 bool isPointerLike = isa<PointerLikeType>(varType);
// Unsupported variable types cannot get a recipe.
1206 if (!isMappable && !isPointerLike)
1207 return std::nullopt;
// `needsFree` is filled in by createInitRegion.
1215 bool needsFree =
false;
1217 createInitRegion(builder, loc, varType, varName, bounds, needsFree);
1219 return std::nullopt;
1222 std::unique_ptr<Block> destroyBlock;
// The value yielded by the init block is handed to the destroy region builder.
1225 auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
1226 Value allocRes = yieldOp.getOperand(0);
1228 destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
1230 return std::nullopt;
1236 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
// Ownership of the blocks passes to the recipe's regions.
1239 recipe.getInitRegion().push_back(initBlock.release());
1241 recipe.getDestroyRegion().push_back(destroyBlock.release());
/// Region verifier for acc.firstprivate.recipe. Visible checks: the "init"
/// region against the recipe type, a non-empty copy region whose arguments
/// are described as two values of the privatization type, and a "destroy"
/// region check reached after testing whether the destroy region is empty.
/// Several conditions and callee names are outside this excerpt.
1250 LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1252 "privatization",
"init",
getType(),
// The copy region is mandatory.
1256 if (getCopyRegion().empty())
1257 return emitOpError() <<
"expects non-empty copy region";
1262 return emitOpError() <<
"expects copy region with two arguments of the "
1263 "privatization type";
1265 if (getDestroyRegion().empty())
1269 "privatization",
"destroy",
/// Factory returning an optional FirstprivateRecipeOp (the line carrying the
/// function name is outside this excerpt). It yields std::nullopt when the
/// variable type is neither mappable nor pointer-like, and on the other
/// visible failure paths; otherwise it creates the recipe and moves the
/// generated init, copy and destroy blocks into its regions.
1276 std::optional<FirstprivateRecipeOp>
1278 StringRef recipeName,
Type varType,
1281 bool isMappable = isa<MappableType>(varType);
1282 bool isPointerLike = isa<PointerLikeType>(varType);
// Unsupported variable types cannot get a recipe.
1285 if (!isMappable && !isPointerLike)
1286 return std::nullopt;
// `needsFree` is filled in by createInitRegion.
1294 bool needsFree =
false;
1296 createInitRegion(builder, loc, varType, varName, bounds, needsFree);
1298 return std::nullopt;
1300 auto copyBlock = createCopyRegion(builder, loc, varType, bounds);
1302 return std::nullopt;
1305 std::unique_ptr<Block> destroyBlock;
// The value yielded by the init block is handed to the destroy region builder.
1308 auto yieldOp = cast<acc::YieldOp>(initBlock->getTerminator());
1309 Value allocRes = yieldOp.getOperand(0);
1311 destroyBlock = createDestroyRegion(builder, loc, varType, allocRes, bounds);
1313 return std::nullopt;
1319 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
// Ownership of the blocks passes to the recipe's regions.
1322 recipe.getInitRegion().push_back(initBlock.release());
1323 recipe.getCopyRegion().push_back(copyBlock.release());
1325 recipe.getDestroyRegion().push_back(destroyBlock.release());
/// Region verifier for acc.reduction.recipe. Visible checks: the combiner
/// region must be non-empty, its leading block arguments are described as
/// two values of the reduction type, and every acc.yield inside it must
/// carry exactly one operand of the recipe type. Some conditions are outside
/// this excerpt.
1334 LogicalResult acc::ReductionRecipeOp::verifyRegions() {
// The combiner region is mandatory.
1340 if (getCombinerRegion().empty())
1341 return emitOpError() <<
"expects non-empty combiner region";
1343 Block &reductionBlock = getCombinerRegion().
front();
1347 return emitOpError() <<
"expects combiner region with the first two "
1348 <<
"arguments of the reduction type";
// Each yield must return a single value of the reduction type.
1350 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1351 if (yieldOp.getOperands().size() != 1 ||
1352 yieldOp.getOperands().getTypes()[0] !=
getType())
1353 return emitOpError() <<
"expects combiner region to yield a value "
1354 "of the reduction type";
1370 if (parser.parseAttribute(attributes.emplace_back()) ||
1371 parser.parseArrow() ||
1372 parser.parseOperand(operands.emplace_back()) ||
1373 parser.parseColonType(types.emplace_back()))
1387 std::optional<mlir::ArrayAttr> attributes) {
1388 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
1389 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
1390 << std::get<1>(it).getType();
1399 template <
typename Op>
1403 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1404 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1405 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1406 operand.getDefiningOp()))
1408 "expect data entry/exit operation or acc.getdeviceptr "
1413 template <
typename Op>
1414 static LogicalResult
1417 llvm::StringRef symbolName,
bool checkOperandType =
true) {
1418 if (!operands.empty()) {
1419 if (!attributes || attributes->size() != operands.size())
1421 <<
"expected as many " << symbolName <<
" symbol reference as "
1422 << operandName <<
" operands";
1426 <<
"unexpected " << symbolName <<
" symbol reference";
1431 for (
auto args : llvm::zip(operands, *attributes)) {
1434 if (!set.insert(operand).second)
1436 << operandName <<
" operand appears more than once";
1439 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1440 auto decl = SymbolTable::lookupNearestSymbolFrom<Op>(op, symbolRef);
1443 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1444 << operandName <<
" declaration";
1446 if (checkOperandType && decl.getType() && decl.getType() != varType)
1447 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1448 <<
") to be the same type as " << operandName
1449 <<
" declaration (" << decl.getType() <<
")";
/// Returns the combined number of reduction, private, firstprivate and
/// data-clause operands.
1455 unsigned ParallelOp::getNumDataOperands() {
1456 return getReductionOperands().size() + getPrivateOperands().size() +
1457 getFirstprivateOperands().size() + getDataClauseOperands().size();
/// Returns the operand found `i` positions after the wait operands and the
/// operands counted in `numOptional`.
1460 Value ParallelOp::getDataOperand(
unsigned i) {
// `numOptional` is declared on a line outside this excerpt. Here it grows by
// the num_gangs, num_workers and vector_length operand counts, plus one each
// for a present if condition and self condition.
1462 numOptional += getNumGangs().size();
1463 numOptional += getNumWorkers().size();
1464 numOptional += getVectorLength().size();
1465 numOptional += getIfCond() ? 1 : 0;
1466 numOptional += getSelfCond() ? 1 : 0;
1467 return getOperand(getWaitOperands().size() + numOptional + i);
1470 template <
typename Op>
1472 ArrayAttr deviceTypes,
1473 llvm::StringRef keyword) {
1474 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1475 return op.
emitOpError() << keyword <<
" operands count must match "
1476 << keyword <<
" device_type count";
1480 template <
typename Op>
1483 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1484 std::size_t numOperandsInSegments = 0;
1485 std::size_t nbOfSegments = 0;
1488 for (
auto segCount : segments.
asArrayRef()) {
1489 if (maxInSegment != 0 && segCount > maxInSegment)
1490 return op.
emitOpError() << keyword <<
" expects a maximum of "
1491 << maxInSegment <<
" values per segment";
1492 numOperandsInSegments += segCount;
1497 if ((numOperandsInSegments != operands.size()) ||
1498 (!deviceTypes && !operands.empty()))
1500 << keyword <<
" operand count does not match count in segments";
1501 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1503 << keyword <<
" segment count does not match device_type count";
1508 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
1509 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1510 "privatizations",
false)))
1512 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
1513 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1514 "firstprivate",
"firstprivatizations",
false)))
1516 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
1517 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1518 "reductions",
false)))
1522 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1523 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1527 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1528 getWaitOperandsDeviceTypeAttr(),
"wait")))
1532 getNumWorkersDeviceTypeAttr(),
1537 getVectorLengthDeviceTypeAttr(),
1542 getAsyncOperandsDeviceTypeAttr(),
1546 if (
failed(checkWaitAndAsyncConflict<acc::ParallelOp>(*
this)))
1549 return checkDataOperands<acc::ParallelOp>(*
this, getDataClauseOperands());
1555 mlir::acc::DeviceType deviceType) {
1558 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1563 bool acc::ParallelOp::hasAsyncOnly() {
1567 bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1575 mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1580 mlir::Value acc::ParallelOp::getNumWorkersValue() {
1585 acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1590 mlir::Value acc::ParallelOp::getVectorLengthValue() {
1595 acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1597 getVectorLength(), deviceType);
1605 ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1607 getNumGangsSegments(), deviceType);
1610 bool acc::ParallelOp::hasWaitOnly() {
1614 bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1623 ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1625 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1626 getHasWaitDevnum(), deviceType);
1633 mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1635 getWaitOperandsSegments(), getHasWaitDevnum(),
1651 odsBuilder, odsState, asyncOperands,
nullptr,
1652 nullptr, waitOperands,
nullptr,
1654 nullptr, numGangs,
nullptr,
1655 nullptr, numWorkers,
1656 nullptr, vectorLength,
1657 nullptr, ifCond, selfCond,
1658 nullptr, reductionOperands,
nullptr,
1659 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1660 nullptr, dataClauseOperands,
/// Hands `newValue`, the mutable num_workers operand range and the current
/// num_workers device-type attribute to addDeviceTypeAffectedOperandHelper
/// and stores the attribute it returns. The parameter list is outside this
/// excerpt.
1664 void acc::ParallelOp::addNumWorkersOperand(
1667 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1668 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1669 getNumWorkersMutable()));
/// Hands `newValue`, the mutable vector_length operand range and the current
/// vector_length device-type attribute to addDeviceTypeAffectedOperandHelper
/// and stores the attribute it returns. The parameter list is outside this
/// excerpt.
1671 void acc::ParallelOp::addVectorLengthOperand(
1674 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1675 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1676 getVectorLengthMutable()));
/// Replaces the asyncOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`. The parameter list is outside this excerpt.
1679 void acc::ParallelOp::addAsyncOnly(
1681 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1682 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
/// Hands `newValue`, the mutable async operand range and the current async
/// device-type attribute to addDeviceTypeAffectedOperandHelper and stores
/// the attribute it returns. The parameter list is outside this excerpt.
1685 void acc::ParallelOp::addAsyncOperand(
1688 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1689 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1690 getAsyncOperandsMutable()));
/// Adds num_gangs operands together with their segment bookkeeping. The
/// parameter list and the declaration of `segments` are outside this excerpt.
1693 void acc::ParallelOp::addNumGangsOperands(
// Start from the existing segment sizes, when there are any.
1697 if (getNumGangsSegments())
1698 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
// The helper receives the operand range and `segments`; its result becomes
// the new device-type attribute.
1700 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1701 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1702 getNumGangsMutable(), segments));
1704 setNumGangsSegments(segments);
/// Replaces the waitOnly attribute with the result of
/// addDeviceTypeAffectedOperandHelper applied to the current attribute and
/// `effectiveDeviceTypes`. The parameter list is outside this excerpt.
1706 void acc::ParallelOp::addWaitOnly(
1708 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1709 effectiveDeviceTypes));
/// Adds wait operands together with their segment and devnum bookkeeping.
/// The parameter list, the declarations of `segments` and `hasDevnums`, and
/// the final attribute update are outside this excerpt.
1711 void acc::ParallelOp::addWaitOperands(
// Start from the existing segment sizes, when there are any.
1716 if (getWaitOperandsSegments())
1717 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
// The helper receives the operand range and `segments`; its result becomes
// the new device-type attribute.
1719 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1720 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1721 getWaitOperandsMutable(), segments));
1722 setWaitOperandsSegments(segments);
// Existing has-devnum entries are carried over first.
1725 if (getHasWaitDevnumAttr())
1726 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// A count of at least one is used even when no device type was given.
1729 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
/// Appends the result of `op` to the private operands and starts rebuilding
/// the privatization recipe list from the existing attribute. The use of
/// `recipe` and the final attribute update are outside this excerpt.
1734 void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1735 mlir::acc::PrivateOp op,
1736 mlir::acc::PrivateRecipeOp recipe) {
1737 getPrivateOperandsMutable().append(op.getResult());
// Keep the recipes that are already attached.
1741 if (getPrivatizationRecipesAttr())
1742 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
/// Appends the result of `op` to the firstprivate operands and starts
/// rebuilding the firstprivatization recipe list from the existing attribute.
/// The use of `recipe` and the final attribute update are outside this
/// excerpt.
1749 void acc::ParallelOp::addFirstPrivatization(
1750 MLIRContext *context, mlir::acc::FirstprivateOp op,
1751 mlir::acc::FirstprivateRecipeOp recipe) {
1752 getFirstprivateOperandsMutable().append(op.getResult());
// Keep the recipes that are already attached.
1756 if (getFirstprivatizationRecipesAttr())
1757 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
/// Appends the result of `op` to the reduction operands and starts
/// rebuilding the reduction recipe list from the existing attribute. The use
/// of `recipe` and the final attribute update are outside this excerpt.
1764 void acc::ParallelOp::addReduction(
MLIRContext *context,
1765 mlir::acc::ReductionOp op,
1766 mlir::acc::ReductionRecipeOp recipe) {
1767 getReductionOperandsMutable().append(op.getResult());
// Keep the recipes that are already attached.
1771 if (getReductionRecipesAttr())
1772 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1791 int32_t crtOperandsSize = operands.size();
1794 if (parser.parseOperand(operands.emplace_back()) ||
1795 parser.parseColonType(types.emplace_back()))
1800 seg.push_back(operands.size() - crtOperandsSize);
1824 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1826 p <<
" [" << attr <<
"]";
1831 std::optional<mlir::ArrayAttr> deviceTypes,
1832 std::optional<mlir::DenseI32ArrayAttr> segments) {
1834 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1836 llvm::interleaveComma(
1837 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1838 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1858 int32_t crtOperandsSize = operands.size();
1862 if (parser.parseOperand(operands.emplace_back()) ||
1863 parser.parseColonType(types.emplace_back()))
1869 seg.push_back(operands.size() - crtOperandsSize);
1895 std::optional<mlir::DenseI32ArrayAttr> segments) {
1897 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1899 llvm::interleaveComma(
1900 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1901 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1914 mlir::ArrayAttr &keywordOnly) {
1918 bool needCommaBeforeOperands =
false;
1931 if (parser.parseAttribute(keywordAttrs.emplace_back()))
1938 needCommaBeforeOperands =
true;
1948 int32_t crtOperandsSize = operands.size();
1960 if (parser.parseOperand(operands.emplace_back()) ||
1961 parser.parseColonType(types.emplace_back()))
1967 seg.push_back(operands.size() - crtOperandsSize);
1996 if (attrs->size() != 1)
1998 if (
auto deviceTypeAttr =
1999 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2006 std::optional<mlir::ArrayAttr> deviceTypes,
2007 std::optional<mlir::DenseI32ArrayAttr> segments,
2008 std::optional<mlir::ArrayAttr> hasDevNum,
2009 std::optional<mlir::ArrayAttr> keywordOnly) {
2022 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2024 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2025 if (boolAttr && boolAttr.getValue())
2027 llvm::interleaveComma(
2028 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2029 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2046 if (parser.parseOperand(operands.emplace_back()) ||
2047 parser.parseColonType(types.emplace_back()))
2049 if (succeeded(parser.parseOptionalLSquare())) {
2050 if (parser.parseAttribute(attributes.emplace_back()) ||
2051 parser.parseRSquare())
2054 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2055 parser.getContext(), mlir::acc::DeviceType::None));
2069 std::optional<mlir::ArrayAttr> deviceTypes) {
2072 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2073 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2082 mlir::ArrayAttr &keywordOnlyDeviceType) {
2085 bool needCommaBeforeOperands =
false;
2091 keywordOnlyDeviceType =
2100 if (parser.parseAttribute(
2101 keywordOnlyDeviceTypeAttributes.emplace_back()))
2108 needCommaBeforeOperands =
true;
2116 if (parser.parseOperand(operands.emplace_back()) ||
2117 parser.parseColonType(types.emplace_back()))
2119 if (succeeded(parser.parseOptionalLSquare())) {
2120 if (parser.parseAttribute(attributes.emplace_back()) ||
2121 parser.parseRSquare())
2124 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2125 parser.getContext(), mlir::acc::DeviceType::None));
2131 if (
failed(parser.parseRParen()))
2143 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2145 if (operands.begin() == operands.end() &&
2161 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2162 mlir::Type &operandType, mlir::UnitAttr &attr) {
2185 std::optional<mlir::Value> operand,
2187 mlir::UnitAttr attr) {
2209 if (parser.parseOperand(operands.emplace_back()))
2217 if (parser.parseType(types.emplace_back()))
2232 mlir::UnitAttr attr) {
2237 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2239 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2245 mlir::acc::CombinedConstructsTypeAttr &attr) {
2248 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2251 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2254 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2257 "expected compute construct name");
2265 mlir::acc::CombinedConstructsTypeAttr attr) {
2267 switch (attr.getValue()) {
2268 case mlir::acc::CombinedConstructsType::KernelsLoop:
2271 case mlir::acc::CombinedConstructsType::ParallelLoop:
2274 case mlir::acc::CombinedConstructsType::SerialLoop:
/// Returns the combined number of reduction, private, firstprivate and
/// data-clause operands.
2285 unsigned SerialOp::getNumDataOperands() {
2286 return getReductionOperands().size() + getPrivateOperands().size() +
2287 getFirstprivateOperands().size() + getDataClauseOperands().size();
/// Returns the operand found `i` positions after the wait operands and the
/// operands counted in `numOptional`.
2290 Value SerialOp::getDataOperand(
unsigned i) {
// `numOptional` is declared on a line outside this excerpt. Here it grows by
// one each for a present if condition and self condition.
2292 numOptional += getIfCond() ? 1 : 0;
2293 numOptional += getSelfCond() ? 1 : 0;
2294 return getOperand(getWaitOperands().size() + numOptional + i);
2297 bool acc::SerialOp::hasAsyncOnly() {
2301 bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2309 mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2314 bool acc::SerialOp::hasWaitOnly() {
2318 bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2327 SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2329 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2330 getHasWaitDevnum(), deviceType);
2337 mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2339 getWaitOperandsSegments(), getHasWaitDevnum(),
2344 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
2345 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2346 "privatizations",
false)))
2348 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
2349 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2350 "firstprivate",
"firstprivatizations",
false)))
2352 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
2353 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2354 "reductions",
false)))
2358 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2359 getWaitOperandsDeviceTypeAttr(),
"wait")))
2363 getAsyncOperandsDeviceTypeAttr(),
2367 if (
failed(checkWaitAndAsyncConflict<acc::SerialOp>(*
this)))
2370 return checkDataOperands<acc::SerialOp>(*
this, getDataClauseOperands());
2373 void acc::SerialOp::addAsyncOnly(
2375 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2376 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2379 void acc::SerialOp::addAsyncOperand(
2382 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2383 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2384 getAsyncOperandsMutable()));
2387 void acc::SerialOp::addWaitOnly(
2389 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2390 effectiveDeviceTypes));
2392 void acc::SerialOp::addWaitOperands(
2397 if (getWaitOperandsSegments())
2398 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2400 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2401 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2402 getWaitOperandsMutable(), segments));
2403 setWaitOperandsSegments(segments);
2406 if (getHasWaitDevnumAttr())
2407 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2410 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2415 void acc::SerialOp::addPrivatization(
MLIRContext *context,
2416 mlir::acc::PrivateOp op,
2417 mlir::acc::PrivateRecipeOp recipe) {
2418 getPrivateOperandsMutable().append(op.getResult());
2422 if (getPrivatizationRecipesAttr())
2423 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2430 void acc::SerialOp::addFirstPrivatization(
2431 MLIRContext *context, mlir::acc::FirstprivateOp op,
2432 mlir::acc::FirstprivateRecipeOp recipe) {
2433 getFirstprivateOperandsMutable().append(op.getResult());
2437 if (getFirstprivatizationRecipesAttr())
2438 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2445 void acc::SerialOp::addReduction(
MLIRContext *context,
2446 mlir::acc::ReductionOp op,
2447 mlir::acc::ReductionRecipeOp recipe) {
2448 getReductionOperandsMutable().append(op.getResult());
2452 if (getReductionRecipesAttr())
2453 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
/// Returns the number of data operands of this KernelsOp, which is the size
/// of its data-clause operand list.
2464 unsigned KernelsOp::getNumDataOperands() {
2465 return getDataClauseOperands().size();
2468 Value KernelsOp::getDataOperand(
unsigned i) {
2470 numOptional += getWaitOperands().size();
2471 numOptional += getNumGangs().size();
2472 numOptional += getNumWorkers().size();
2473 numOptional += getVectorLength().size();
2474 numOptional += getIfCond() ? 1 : 0;
2475 numOptional += getSelfCond() ? 1 : 0;
2476 return getOperand(numOptional + i);
2479 bool acc::KernelsOp::hasAsyncOnly() {
2483 bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2491 mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2496 mlir::Value acc::KernelsOp::getNumWorkersValue() {
2501 acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2506 mlir::Value acc::KernelsOp::getVectorLengthValue() {
2511 acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2513 getVectorLength(), deviceType);
2521 KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2523 getNumGangsSegments(), deviceType);
2526 bool acc::KernelsOp::hasWaitOnly() {
2530 bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2539 KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2541 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2542 getHasWaitDevnum(), deviceType);
2549 mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2551 getWaitOperandsSegments(), getHasWaitDevnum(),
2557 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2558 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2562 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2563 getWaitOperandsDeviceTypeAttr(),
"wait")))
2567 getNumWorkersDeviceTypeAttr(),
2572 getVectorLengthDeviceTypeAttr(),
2577 getAsyncOperandsDeviceTypeAttr(),
2581 if (
failed(checkWaitAndAsyncConflict<acc::KernelsOp>(*
this)))
2584 return checkDataOperands<acc::KernelsOp>(*
this, getDataClauseOperands());
2587 void acc::KernelsOp::addNumWorkersOperand(
2590 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2591 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2592 getNumWorkersMutable()));
2595 void acc::KernelsOp::addVectorLengthOperand(
2598 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2599 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2600 getVectorLengthMutable()));
2602 void acc::KernelsOp::addAsyncOnly(
2604 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2605 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2608 void acc::KernelsOp::addAsyncOperand(
2611 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2612 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2613 getAsyncOperandsMutable()));
2616 void acc::KernelsOp::addNumGangsOperands(
2620 if (getNumGangsSegmentsAttr())
2621 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2623 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2624 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2625 getNumGangsMutable(), segments));
2627 setNumGangsSegments(segments);
2630 void acc::KernelsOp::addWaitOnly(
2632 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2633 effectiveDeviceTypes));
2635 void acc::KernelsOp::addWaitOperands(
2640 if (getWaitOperandsSegments())
2641 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2643 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2644 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2645 getWaitOperandsMutable(), segments));
2646 setWaitOperandsSegments(segments);
2649 if (getHasWaitDevnumAttr())
2650 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2653 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2663 if (getDataClauseOperands().empty())
2664 return emitError(
"at least one operand must appear on the host_data "
2667 for (
mlir::Value operand : getDataClauseOperands())
2668 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2669 return emitError(
"expect data entry operation as defining op");
2675 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2687 bool &needCommaBetweenValues,
bool &newValue) {
2694 attributes.push_back(gangArgType);
2695 needCommaBetweenValues =
true;
2706 mlir::ArrayAttr &gangOnlyDeviceType) {
2711 bool needCommaBetweenValues =
false;
2712 bool needCommaBeforeOperands =
false;
2718 gangOnlyDeviceType =
2727 if (parser.parseAttribute(
2728 gangOnlyDeviceTypeAttributes.emplace_back()))
2735 needCommaBeforeOperands =
true;
2739 mlir::acc::GangArgType::Num);
2741 mlir::acc::GangArgType::Dim);
2743 parser.
getContext(), mlir::acc::GangArgType::Static);
2746 if (needCommaBeforeOperands) {
2747 needCommaBeforeOperands =
false;
2754 int32_t crtOperandsSize = gangOperands.size();
2756 bool newValue =
false;
2757 bool needValue =
false;
2758 if (needCommaBetweenValues) {
2766 gangOperands, gangOperandsType,
2767 gangArgTypeAttributes, argNum,
2768 needCommaBetweenValues, newValue)))
2771 gangOperands, gangOperandsType,
2772 gangArgTypeAttributes, argDim,
2773 needCommaBetweenValues, newValue)))
2776 gangOperands, gangOperandsType,
2777 gangArgTypeAttributes, argStatic,
2778 needCommaBetweenValues, newValue)))
2781 if (!newValue && needValue) {
2783 "new value expected after comma");
2791 if (gangOperands.empty())
2794 "expect at least one of num, dim or static values");
2800 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2808 seg.push_back(gangOperands.size() - crtOperandsSize);
2816 gangArgTypeAttributes.end());
2821 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2830 std::optional<mlir::ArrayAttr> gangArgTypes,
2831 std::optional<mlir::ArrayAttr> deviceTypes,
2832 std::optional<mlir::DenseI32ArrayAttr> segments,
2833 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2835 if (operands.begin() == operands.end() &&
2850 llvm::interleaveComma(
llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2852 llvm::interleaveComma(
2853 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2854 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2855 (*gangArgTypes)[opIdx]);
2856 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2857 p << LoopOp::getGangNumKeyword();
2858 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2859 p << LoopOp::getGangDimKeyword();
2860 else if (gangArgTypeAttr.getValue() ==
2861 mlir::acc::GangArgType::Static)
2862 p << LoopOp::getGangStaticKeyword();
2863 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2874 std::optional<mlir::ArrayAttr> segments,
2875 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2878 for (
auto attr : *segments) {
2879 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2880 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2888 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2891 for (
auto attr : deviceTypes) {
2892 auto deviceTypeAttr =
2893 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2894 if (!deviceTypeAttr)
2896 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2903 if (getUpperbound().size() != getStep().size())
2904 return emitError() <<
"number of upperbounds expected to be the same as "
2907 if (getUpperbound().size() != getLowerbound().size())
2908 return emitError() <<
"number of upperbounds expected to be the same as "
2909 "number of lowerbounds";
2911 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2912 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2913 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2914 <<
" as upperbound size";
2917 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
2918 return emitOpError() <<
"collapse device_type attr must be define when"
2919 <<
" collapse attr is present";
2921 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
2922 getCollapseAttr().getValue().size() !=
2923 getCollapseDeviceTypeAttr().getValue().size())
2924 return emitOpError() <<
"collapse attribute count must match collapse"
2925 <<
" device_type count";
2927 return emitOpError()
2928 <<
"duplicate device_type found in collapseDeviceType attribute";
2931 if (!getGangOperands().empty()) {
2932 if (!getGangOperandsArgType())
2933 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
2934 <<
" when gang operands are present";
2936 if (getGangOperands().size() !=
2937 getGangOperandsArgTypeAttr().getValue().size())
2938 return emitOpError() <<
"gangOperandsArgType attribute count must match"
2939 <<
" gangOperands count";
2942 return emitOpError() <<
"duplicate device_type found in gang attribute";
2945 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
2946 getGangOperandsDeviceTypeAttr(),
"gang")))
2951 return emitOpError() <<
"duplicate device_type found in worker attribute";
2953 return emitOpError() <<
"duplicate device_type found in "
2954 "workerNumOperandsDeviceType attribute";
2956 getWorkerNumOperandsDeviceTypeAttr(),
2962 return emitOpError() <<
"duplicate device_type found in vector attribute";
2964 return emitOpError() <<
"duplicate device_type found in "
2965 "vectorOperandsDeviceType attribute";
2967 getVectorOperandsDeviceTypeAttr(),
2972 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
2973 getTileOperandsDeviceTypeAttr(),
"tile")))
2977 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
2981 return emitError() <<
"only one of auto, independent, seq can be present "
2987 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
2990 bool hasDefaultSeq =
2992 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
2995 bool hasDefaultIndependent =
2996 getIndependentAttr()
2998 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3001 bool hasDefaultAuto =
3003 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3006 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3008 <<
"at least one of auto, independent, seq must be present";
3013 for (
auto attr : getSeqAttr()) {
3014 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3015 if (hasVector(deviceTypeAttr.getValue()) ||
3016 getVectorValue(deviceTypeAttr.getValue()) ||
3017 hasWorker(deviceTypeAttr.getValue()) ||
3018 getWorkerValue(deviceTypeAttr.getValue()) ||
3019 hasGang(deviceTypeAttr.getValue()) ||
3020 getGangValue(mlir::acc::GangArgType::Num,
3021 deviceTypeAttr.getValue()) ||
3022 getGangValue(mlir::acc::GangArgType::Dim,
3023 deviceTypeAttr.getValue()) ||
3024 getGangValue(mlir::acc::GangArgType::Static,
3025 deviceTypeAttr.getValue()))
3026 return emitError() <<
"gang, worker or vector cannot appear with seq";
3030 if (
failed(checkSymOperandList<mlir::acc::PrivateRecipeOp>(
3031 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
3032 "privatizations",
false)))
3035 if (
failed(checkSymOperandList<mlir::acc::FirstprivateRecipeOp>(
3036 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
3037 "firstprivate",
"firstprivatizations",
false)))
3040 if (
failed(checkSymOperandList<mlir::acc::ReductionRecipeOp>(
3041 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
3042 "reductions",
false)))
3045 if (getCombined().has_value() &&
3046 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3047 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3048 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3049 return emitError(
"unexpected combined constructs attribute");
3053 if (getRegion().empty())
3054 return emitError(
"expected non-empty body.");
3057 if (isContainerLike()) {
3060 uint64_t collapseCount = getCollapseValue().value_or(1);
3061 if (getCollapseAttr()) {
3062 for (
auto collapseEntry : getCollapseAttr()) {
3063 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3064 if (intAttr.getValue().getZExtValue() > collapseCount)
3065 collapseCount = intAttr.getValue().getZExtValue();
3073 bool foundSibling =
false;
3075 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3077 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3079 foundSibling =
true;
3084 expectedParent = op;
3087 if (collapseCount == 0)
3093 return emitError(
"found sibling loops inside container-like acc.loop");
3094 if (collapseCount != 0)
3095 return emitError(
"failed to find enough loop-like operations inside "
3096 "container-like acc.loop");
/// Returns the number of data operands of this LoopOp: the sum of the sizes
/// of its reduction, private and firstprivate operand lists.
3102 unsigned LoopOp::getNumDataOperands() {
3103 return getReductionOperands().size() + getPrivateOperands().size() +
3104 getFirstprivateOperands().size();
/// Returns the `i`-th data operand of this LoopOp.
///
/// The operand is fetched with `getOperand` after offsetting `i` by the
/// combined sizes of the lower-bound, upper-bound, step, gang, vector,
/// worker-num, tile and cache operand lists.
3107 Value LoopOp::getDataOperand(
unsigned i) {
// Start the offset with the loop-control operands (bounds and steps).
3108 unsigned numOptional =
3109 getLowerbound().size() + getUpperbound().size() + getStep().size();
// Then add every other operand group that is counted ahead of the data
// operands.
3110 numOptional += getGangOperands().size();
3111 numOptional += getVectorOperands().size();
3112 numOptional += getWorkerNumOperands().size();
3113 numOptional += getTileOperands().size();
3114 numOptional += getCacheOperands().size();
3115 return getOperand(numOptional + i);
3120 bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3124 bool LoopOp::hasIndependent() {
3128 bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3134 bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3142 mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3144 getVectorOperands(), deviceType);
3149 bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3157 mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3159 getWorkerNumOperands(), deviceType);
3164 bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3173 LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3175 getTileOperandsSegments(), deviceType);
3178 std::optional<int64_t> LoopOp::getCollapseValue() {
3182 std::optional<int64_t>
3183 LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3184 if (!getCollapseAttr())
3185 return std::nullopt;
3186 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3188 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3189 return intAttr.getValue().getZExtValue();
3191 return std::nullopt;
3194 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3198 mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3199 mlir::acc::DeviceType deviceType) {
3200 if (getGangOperands().empty())
3202 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3203 int32_t nbOperandsBefore = 0;
3204 for (
unsigned i = 0; i < *pos; ++i)
3205 nbOperandsBefore += (*getGangOperandsSegments())[i];
3208 .drop_front(nbOperandsBefore)
3209 .take_front((*getGangOperandsSegments())[*pos]);
3211 int32_t argTypeIdx = nbOperandsBefore;
3212 for (
auto value : values) {
3213 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3214 (*getGangOperandsArgType())[argTypeIdx]);
3215 if (gangArgTypeAttr.getValue() == gangArgType)
3225 bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3230 return {&getRegion()};
3274 if (!regionArgs.empty()) {
3275 p << acc::LoopOp::getControlKeyword() <<
"(";
3276 llvm::interleaveComma(regionArgs, p,
3278 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3279 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3280 <<
" : " << stepType <<
") ";
3287 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3288 effectiveDeviceTypes));
3291 void acc::LoopOp::addIndependent(
3293 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3294 context, getIndependentAttr(), effectiveDeviceTypes));
3299 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3300 effectiveDeviceTypes));
3303 void acc::LoopOp::setCollapseForDeviceTypes(
3305 llvm::APInt value) {
3309 assert((getCollapseAttr() ==
nullptr) ==
3310 (getCollapseDeviceTypeAttr() ==
nullptr));
3311 assert(value.getBitWidth() == 64);
3313 if (getCollapseAttr()) {
3314 for (
const auto &existing :
3315 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3316 newValues.push_back(std::get<0>(existing));
3317 newDeviceTypes.push_back(std::get<1>(existing));
3321 if (effectiveDeviceTypes.empty()) {
3324 newValues.push_back(
3326 newDeviceTypes.push_back(
3329 for (DeviceType dt : effectiveDeviceTypes) {
3330 newValues.push_back(
3337 setCollapseDeviceTypeAttr(
ArrayAttr::get(context, newDeviceTypes));
3340 void acc::LoopOp::setTileForDeviceTypes(
3344 if (getTileOperandsSegments())
3345 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3347 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3348 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3349 getTileOperandsMutable(), segments));
3351 setTileOperandsSegments(segments);
3354 void acc::LoopOp::addVectorOperand(
3357 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3358 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3359 newValue, getVectorOperandsMutable()));
3362 void acc::LoopOp::addEmptyVector(
3364 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3365 effectiveDeviceTypes));
3368 void acc::LoopOp::addWorkerNumOperand(
3371 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3372 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3373 newValue, getWorkerNumOperandsMutable()));
3376 void acc::LoopOp::addEmptyWorker(
3378 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3379 effectiveDeviceTypes));
3382 void acc::LoopOp::addEmptyGang(
3384 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3385 effectiveDeviceTypes));
3388 bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3389 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3390 return attr.getValue() == dt;
3392 auto testFromArr = [=](ArrayAttr arr) ->
bool {
3393 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3396 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3398 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3400 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
/// Returns true if any of the argument-less gang/worker/vector queries
/// reports something: `hasVector()`, `getVectorValue()`, `hasWorker()`,
/// `getWorkerValue()`, `hasGang()`, or `getGangValue` for the Num, Dim or
/// Static gang argument types.
/// NOTE(review): the argument-less overloads presumably query the default
/// (no device_type) entry, as the function name suggests -- confirm against
/// their definitions.
3406 bool acc::LoopOp::hasDefaultGangWorkerVector() {
3407 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3408 hasGang() || getGangValue(GangArgType::Num) ||
3409 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3413 acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3414 if (hasSeq(deviceType))
3415 return LoopParMode::loop_seq;
3416 if (hasAuto(deviceType))
3417 return LoopParMode::loop_auto;
3418 if (hasIndependent(deviceType))
3419 return LoopParMode::loop_independent;
3421 return LoopParMode::loop_seq;
3423 return LoopParMode::loop_auto;
3424 assert(hasIndependent() &&
3425 "loop must have default auto, seq, or independent");
3426 return LoopParMode::loop_independent;
3429 void acc::LoopOp::addGangOperands(
3434 getGangOperandsSegments())
3435 llvm::copy(*existingSegments, std::back_inserter(segments));
3437 unsigned beforeCount = segments.size();
3439 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3440 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3441 getGangOperandsMutable(), segments));
3443 setGangOperandsSegments(segments);
3450 unsigned numAdded = segments.size() - beforeCount;
3454 if (getGangOperandsArgTypeAttr())
3455 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3457 for (
auto i : llvm::index_range(0u, numAdded)) {
3458 llvm::transform(argTypes, std::back_inserter(gangTypes),
3459 [=](mlir::acc::GangArgType gangTy) {
3469 void acc::LoopOp::addPrivatization(
MLIRContext *context,
3470 mlir::acc::PrivateOp op,
3471 mlir::acc::PrivateRecipeOp recipe) {
3472 getPrivateOperandsMutable().append(op.getResult());
3476 if (getPrivatizationRecipesAttr())
3477 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3484 void acc::LoopOp::addFirstPrivatization(
3485 MLIRContext *context, mlir::acc::FirstprivateOp op,
3486 mlir::acc::FirstprivateRecipeOp recipe) {
3487 getFirstprivateOperandsMutable().append(op.getResult());
3491 if (getFirstprivatizationRecipesAttr())
3492 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
3499 void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3500 mlir::acc::ReductionRecipeOp recipe) {
3501 getReductionOperandsMutable().append(op.getResult());
3505 if (getReductionRecipesAttr())
3506 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3521 if (getOperands().empty() && !getDefaultAttr())
3522 return emitError(
"at least one operand or the default attribute "
3523 "must appear on the data operation");
3525 for (
mlir::Value operand : getDataClauseOperands())
3526 if (isa<BlockArgument>(operand) ||
3527 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3528 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3529 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3530 operand.getDefiningOp()))
3531 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3534 if (
failed(checkWaitAndAsyncConflict<acc::DataOp>(*
this)))
/// Returns the number of data operands of this DataOp, which is the size of
/// its data-clause operand list.
3540 unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3542 Value DataOp::getDataOperand(
unsigned i) {
3543 unsigned numOptional = getIfCond() ? 1 : 0;
3545 numOptional += getWaitOperands().size();
3546 return getOperand(numOptional + i);
3549 bool acc::DataOp::hasAsyncOnly() {
3553 bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3561 mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3568 bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3577 DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3579 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3580 getHasWaitDevnum(), deviceType);
3587 mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3589 getWaitOperandsSegments(), getHasWaitDevnum(),
3593 void acc::DataOp::addAsyncOnly(
3595 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3596 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3599 void acc::DataOp::addAsyncOperand(
3602 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3603 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3604 getAsyncOperandsMutable()));
3607 void acc::DataOp::addWaitOnly(
MLIRContext *context,
3609 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3610 effectiveDeviceTypes));
3613 void acc::DataOp::addWaitOperands(
3618 if (getWaitOperandsSegments())
3619 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3621 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3622 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3623 getWaitOperandsMutable(), segments));
3624 setWaitOperandsSegments(segments);
3627 if (getHasWaitDevnumAttr())
3628 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3631 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3644 if (getDataClauseOperands().empty())
3645 return emitError(
"at least one operand must be present in dataOperands on "
3646 "the exit data operation");
3650 if (getAsyncOperand() && getAsync())
3651 return emitError(
"async attribute cannot appear with asyncOperand");
3655 if (!getWaitOperands().empty() && getWait())
3656 return emitError(
"wait attribute cannot appear with waitOperands");
3658 if (getWaitDevnum() && getWaitOperands().empty())
3659 return emitError(
"wait_devnum cannot appear without waitOperands");
/// Returns the number of data operands of this ExitDataOp, which is the size
/// of its data-clause operand list.
3664 unsigned ExitDataOp::getNumDataOperands() {
3665 return getDataClauseOperands().size();
/// Returns the `i`-th data operand of this ExitDataOp.
///
/// The operand is fetched with `getOperand` after offsetting `i` by the
/// number of wait operands plus one for each of the if condition, the async
/// operand and the wait devnum operand that is present.
3668 Value ExitDataOp::getDataOperand(
unsigned i) {
// Each single-valued optional operand contributes 1 only when it is set.
3669 unsigned numOptional = getIfCond() ? 1 : 0;
3670 numOptional += getAsyncOperand() ? 1 : 0;
3671 numOptional += getWaitDevnum() ? 1 : 0;
3672 return getOperand(getWaitOperands().size() + numOptional + i);
3677 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3680 void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3682 assert(effectiveDeviceTypes.empty());
3683 assert(!getAsyncAttr());
3684 assert(!getAsyncOperand());
3689 void ExitDataOp::addAsyncOperand(
3692 assert(effectiveDeviceTypes.empty());
3693 assert(!getAsyncAttr());
3694 assert(!getAsyncOperand());
3696 getAsyncOperandMutable().append(newValue);
3699 void ExitDataOp::addWaitOnly(
MLIRContext *context,
3701 assert(effectiveDeviceTypes.empty());
3702 assert(!getWaitAttr());
3703 assert(getWaitOperands().empty());
3704 assert(!getWaitDevnum());
3709 void ExitDataOp::addWaitOperands(
3712 assert(effectiveDeviceTypes.empty());
3713 assert(!getWaitAttr());
3714 assert(getWaitOperands().empty());
3715 assert(!getWaitDevnum());
3720 getWaitDevnumMutable().append(newValues.front());
3721 newValues = newValues.drop_front();
3724 getWaitOperandsMutable().append(newValues);
3735 if (getDataClauseOperands().empty())
3736 return emitError(
"at least one operand must be present in dataOperands on "
3737 "the enter data operation");
3741 if (getAsyncOperand() && getAsync())
3742 return emitError(
"async attribute cannot appear with asyncOperand");
3746 if (!getWaitOperands().empty() && getWait())
3747 return emitError(
"wait attribute cannot appear with waitOperands");
3749 if (getWaitDevnum() && getWaitOperands().empty())
3750 return emitError(
"wait_devnum cannot appear without waitOperands");
3752 for (
mlir::Value operand : getDataClauseOperands())
3753 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3754 operand.getDefiningOp()))
3755 return emitError(
"expect data entry operation as defining op");
/// Returns the number of data operands of this EnterDataOp, which is the size
/// of its data-clause operand list.
3760 unsigned EnterDataOp::getNumDataOperands() {
3761 return getDataClauseOperands().size();
/// Returns the `i`-th data operand of this EnterDataOp.
///
/// The operand is fetched with `getOperand` after offsetting `i` by the
/// number of wait operands plus one for each of the if condition, the async
/// operand and the wait devnum operand that is present.
3764 Value EnterDataOp::getDataOperand(
unsigned i) {
// Each single-valued optional operand contributes 1 only when it is set.
3765 unsigned numOptional = getIfCond() ? 1 : 0;
3766 numOptional += getAsyncOperand() ? 1 : 0;
3767 numOptional += getWaitDevnum() ? 1 : 0;
3768 return getOperand(getWaitOperands().size() + numOptional + i);
3773 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3776 void EnterDataOp::addAsyncOnly(
3778 assert(effectiveDeviceTypes.empty());
3779 assert(!getAsyncAttr());
3780 assert(!getAsyncOperand());
3785 void EnterDataOp::addAsyncOperand(
3788 assert(effectiveDeviceTypes.empty());
3789 assert(!getAsyncAttr());
3790 assert(!getAsyncOperand());
3792 getAsyncOperandMutable().append(newValue);
3795 void EnterDataOp::addWaitOnly(
MLIRContext *context,
3797 assert(effectiveDeviceTypes.empty());
3798 assert(!getWaitAttr());
3799 assert(getWaitOperands().empty());
3800 assert(!getWaitDevnum());
3805 void EnterDataOp::addWaitOperands(
3808 assert(effectiveDeviceTypes.empty());
3809 assert(!getWaitAttr());
3810 assert(getWaitOperands().empty());
3811 assert(!getWaitDevnum());
3816 getWaitDevnumMutable().append(newValues.front());
3817 newValues = newValues.drop_front();
3820 getWaitOperandsMutable().append(newValues);
3839 LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3846 if (
Value writeVal = op.getWriteOpVal()) {
/// Region verifier for AtomicUpdateOp; forwards to `verifyRegionsCommon()`
/// and returns its result.
3856 LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3862 AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3863 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3865 return dyn_cast<AtomicReadOp>(getSecondOp());
3868 AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3869 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3871 return dyn_cast<AtomicWriteOp>(getSecondOp());
/// Returns the AtomicUpdateOp contained in this AtomicCaptureOp.
///
/// The first op is checked first; otherwise the second op is dyn_cast to
/// AtomicUpdateOp, which yields a null op when it is of a different kind.
3874 AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3875 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3877 return dyn_cast<AtomicUpdateOp>(getSecondOp());
/// Region verifier for AtomicCaptureOp; forwards to verifyRegionsCommon() and
/// returns its result.
3880 LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
/// Operand check shared by declare-style operations, templated on the op type.
///
/// - With `requireAtLeastOneOperand` set (the default), an empty operand list
///   is rejected with a diagnostic.
/// - An operand is rejected when it is a block argument or when its defining
///   op is not one of the data entry ops listed in the isa<> check below.
/// - Asserts document that every accepted operand has a variable and a data
///   clause.
3886 template <
typename Op>
3887 static LogicalResult
3889 bool requireAtLeastOneOperand =
true) {
// An empty operand list is only an error when the caller requires operands.
3890 if (operands.empty() && requireAtLeastOneOperand)
3893 "at least one operand must appear on the declare operation");
// Block arguments have no defining op, so they are rejected up front together
// with operands produced by any op outside the accepted list.
3896 if (isa<BlockArgument>(operand) ||
3897 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3898 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3899 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3900 operand.getDefiningOp()))
3902 "expect valid declare data entry operation or acc.getdeviceptr "
3906 assert(var &&
"declare operands can only be data entry operations which "
3909 std::optional<mlir::acc::DataClause> dataClauseOptional{
3911 assert(dataClauseOptional.has_value() &&
3912 "declare operands can only be data entry operations which must have "
// The cast to void keeps the variable referenced when the assert above is
// compiled out.
3914 (void)dataClauseOptional;
3948 acc::DeviceType dtype) {
3949 unsigned parallelism = 0;
3950 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
3951 parallelism += op.hasWorker(dtype) ? 1 : 0;
3952 parallelism += op.hasVector(dtype) ? 1 : 0;
3953 parallelism += op.hasSeq(dtype) ? 1 : 0;
3958 unsigned baseParallelism =
3961 if (baseParallelism > 1)
3962 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3963 "be present at the same time";
3965 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
3967 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
3972 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
3973 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
3974 "be present at the same time";
3981 mlir::ArrayAttr &bindIdName,
3982 mlir::ArrayAttr &bindStrName,
3983 mlir::ArrayAttr &deviceIdTypes,
3984 mlir::ArrayAttr &deviceStrTypes) {
3991 mlir::Attribute newAttr;
3992 bool isSymbolRefAttr;
3993 auto parseResult = parser.parseAttribute(newAttr);
3994 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
3995 bindIdNameAttrs.push_back(symbolRefAttr);
3996 isSymbolRefAttr = true;
3997 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
3998 bindStrNameAttrs.push_back(stringAttr);
3999 isSymbolRefAttr =
false;
4004 if (isSymbolRefAttr) {
4005 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4006 parser.getContext(), mlir::acc::DeviceType::None));
4008 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4009 parser.getContext(), mlir::acc::DeviceType::None));
4012 if (isSymbolRefAttr) {
4013 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4014 parser.parseRSquare())
4017 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4018 parser.parseRSquare())
4026 bindIdName =
ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4027 bindStrName =
ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4028 deviceIdTypes =
ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4029 deviceStrTypes =
ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4035 std::optional<mlir::ArrayAttr> bindIdName,
4036 std::optional<mlir::ArrayAttr> bindStrName,
4037 std::optional<mlir::ArrayAttr> deviceIdTypes,
4038 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4045 allBindNames.append(bindIdName->begin(), bindIdName->end());
4046 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4051 allBindNames.append(bindStrName->begin(), bindStrName->end());
4052 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4056 if (!allBindNames.empty())
4057 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4058 [&](
const auto &pair) {
4059 p << std::get<0>(pair);
4065 mlir::ArrayAttr &gang,
4066 mlir::ArrayAttr &gangDim,
4067 mlir::ArrayAttr &gangDimDeviceTypes) {
4070 gangDimDeviceTypeAttrs;
4071 bool needCommaBeforeOperands =
false;
4084 if (parser.parseAttribute(gangAttrs.emplace_back()))
4091 needCommaBeforeOperands =
true;
4098 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4099 parser.parseColon() ||
4100 parser.parseAttribute(gangDimAttrs.emplace_back()))
4102 if (succeeded(parser.parseOptionalLSquare())) {
4103 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4104 parser.parseRSquare())
4107 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4108 parser.getContext(), mlir::acc::DeviceType::None));
4114 if (
failed(parser.parseRParen()))
4119 gangDimDeviceTypes =
4126 std::optional<mlir::ArrayAttr> gang,
4127 std::optional<mlir::ArrayAttr> gangDim,
4128 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4131 gang->size() == 1) {
4132 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4145 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4146 [&](
const auto &pair) {
4147 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4148 p << std::get<0>(pair);
4156 mlir::ArrayAttr &deviceTypes) {
4169 if (parser.parseAttribute(attributes.emplace_back()))
4183 std::optional<mlir::ArrayAttr> deviceTypes) {
4186 auto deviceTypeAttr =
4187 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4197 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4205 bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4211 bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4217 bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4221 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4222 RoutineOp::getBindNameValue() {
/// Looks up the bind name recorded on this RoutineOp for `deviceType`.
///
/// Id-based bind names are searched first and produce a SymbolRefAttr; the
/// string-based bind names are searched next. std::nullopt is returned when
/// no bind name is found.
4226 std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4227 RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4230 return std::nullopt;
// Id-based names: findSegment yields the index of `deviceType`, if present,
// and that index selects the matching entry of the bind id name array.
4233 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4234 auto attr = (*getBindIdName())[*pos];
4235 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4236 assert(symbolRefAttr &&
"expected SymbolRef");
4237 return symbolRefAttr;
// String-based names, looked up the same way in the string arrays.
4240 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4241 auto attr = (*getBindStrName())[*pos];
4242 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4243 assert(stringAttr &&
"expected String");
// No bind name is recorded for this device type.
4247 return std::nullopt;
4252 bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4256 std::optional<int64_t> RoutineOp::getGangDimValue() {
/// Returns the gang dim integer recorded on this RoutineOp for `deviceType`,
/// or std::nullopt when no entry exists for that device type.
4260 std::optional<int64_t>
4261 RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4263 return std::nullopt;
// findSegment yields the index of `deviceType` in the gang dim device type
// array; the same index selects the value in the gang dim array.
4264 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
// NOTE(review): the dyn_cast result is used without a null check - this
// assumes every gang dim entry is an IntegerAttr; confirm the op verifier
// guarantees that.
4265 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4266 return intAttr.getInt();
4268 return std::nullopt;
4279 return emitOpError(
"cannot be nested in a compute operation");
/// Device type adder for acc::InitOp. It starts from the entries already
/// stored in the op's device types attribute, when that attribute is set.
4283 void acc::InitOp::addDeviceType(
MLIRContext *context,
4284 mlir::acc::DeviceType deviceType) {
// Carry over the existing attribute entries into `deviceTypes`.
4286 if (getDeviceTypesAttr())
4287 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4301 return emitOpError(
"cannot be nested in a compute operation");
/// Device type adder for acc::ShutdownOp. It starts from the entries already
/// stored in the op's device types attribute, when that attribute is set.
4305 void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4306 mlir::acc::DeviceType deviceType) {
// Carry over the existing attribute entries into `deviceTypes`.
4308 if (getDeviceTypesAttr())
4309 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4323 return emitOpError(
"cannot be nested in a compute operation");
4324 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4325 return emitOpError(
"at least one default_async, device_num, or device_type "
4326 "operand must appear");
4336 if (getDataClauseOperands().empty())
4337 return emitError(
"at least one value must be present in dataOperands");
4340 getAsyncOperandsDeviceTypeAttr(),
4345 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4346 getWaitOperandsDeviceTypeAttr(),
"wait")))
4349 if (
failed(checkWaitAndAsyncConflict<acc::UpdateOp>(*
this)))
4352 for (
mlir::Value operand : getDataClauseOperands())
4353 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4354 operand.getDefiningOp()))
4355 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
/// Returns the number of data clause operands attached to this UpdateOp,
/// i.e. the size of getDataClauseOperands().
4361 unsigned UpdateOp::getNumDataOperands() {
4362 return getDataClauseOperands().size();
/// Returns the i-th data operand of this UpdateOp.
///
/// The index `i` is offset by the number of wait operands plus `numOptional`,
/// which includes one slot for the if condition when it is present.
4365 Value UpdateOp::getDataOperand(
unsigned i) {
4367 numOptional += getIfCond() ? 1 : 0;
// NOTE(review): assumes the data operands follow the wait operands and the
// operands counted in `numOptional` - confirm against the op's operand order.
4368 return getOperand(getWaitOperands().size() + numOptional + i);
4373 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4376 bool UpdateOp::hasAsyncOnly() {
4380 bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4388 mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4398 bool UpdateOp::hasWaitOnly() {
4402 bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4411 UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4413 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4414 getHasWaitDevnum(), deviceType);
4421 mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4423 getWaitOperandsSegments(), getHasWaitDevnum(),
4429 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4430 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
/// Adds `newValue` as an async operand of this UpdateOp for
/// `effectiveDeviceTypes`.
///
/// addDeviceTypeAffectedOperandHelper receives the current async operands
/// device type attribute, the device types, the new value and the mutable
/// async operand range. Its result is stored back as the async operands
/// device type attribute.
4433 void UpdateOp::addAsyncOperand(
4436 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4437 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4438 getAsyncOperandsMutable()));
4443 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4444 effectiveDeviceTypes));
/// Adds `newValues` as wait operands of this UpdateOp for
/// `effectiveDeviceTypes`, updating the wait operand segments and the wait
/// operands device type attribute.
4447 void UpdateOp::addWaitOperands(
// Seed `segments` with the segments already stored on the op, if any.
4452 if (getWaitOperandsSegments())
4453 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
// The helper is given the mutable wait operand range and `segments`; its
// result becomes the new wait operands device type attribute, after which
// `segments` is written back to the op.
4455 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4456 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4457 getWaitOperandsMutable(), segments));
4458 setWaitOperandsSegments(segments);
// Seed `hasDevnums` with the entries of the existing has-wait-devnum
// attribute, if it is set.
4461 if (getHasWaitDevnumAttr())
4462 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4465 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4477 if (getAsyncOperand() && getAsync())
4478 return emitError(
"async attribute cannot appear with asyncOperand");
4480 if (getWaitDevnum() && getWaitOperands().empty())
4481 return emitError(
"wait_devnum cannot appear without waitOperands");
4486 #define GET_OP_CLASSES
4487 #include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4489 #define GET_ATTRDEF_CLASSES
4490 #include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4492 #define GET_TYPEDEF_CLASSES
4493 #include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4504 .Case<ACC_DATA_ENTRY_OPS>(
4505 [&](
auto entry) {
return entry.getVarPtr(); })
4506 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4507 [&](
auto exit) {
return exit.getVarPtr(); })
4525 [&](
auto entry) {
return entry.getVarType(); })
4526 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4527 [&](
auto exit) {
return exit.getVarType(); })
4537 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4538 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4548 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4557 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4567 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4569 dataClause.getBounds().begin(), dataClause.getBounds().end());
4581 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4583 dataClause.getAsyncOperands().begin(),
4584 dataClause.getAsyncOperands().end());
4595 return dataClause.getAsyncOperandsDeviceTypeAttr();
4603 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4610 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4617 std::optional<mlir::acc::DataClause>
4622 .Case<ACC_DATA_ENTRY_OPS>(
4623 [&](
auto entry) {
return entry.getDataClause(); })
4631 [&](
auto entry) {
return entry.getImplicit(); })
4640 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4642 return dataOperands;
4650 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4652 return dataOperands;
4658 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp)) {
static void copy(Location loc, Value dst, Value src, Value size, OpBuilder &builder)
Copies the given number of bytes from src to dst pointers.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static MLIRContext * getContext(OpFoldResult val)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t >> segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t >> segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
static Value max(ImplicitLocOpBuilder &builder, Value value, Value bound)
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
InsertPoint saveInsertionPoint() const
Return a saved insertion point.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
void restoreInsertionPoint(InsertPoint ip)
Restore the insert point to a previously saved point.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
Operation * getParentOp()
Return the parent operation this region is attached to.
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
Builder from ArrayRef<T>.
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
mlir::Operation * getEnclosingComputeOp(mlir::Region &region)
Used to obtain the enclosing compute construct operation that contains the provided region.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
constexpr void enumerate(std::tuple< Tys... > &tuple, CallbackT &&callback)
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
auto get(MLIRContext *context, Ts &&...params)
Helper method that injects context only if needed; this helps unify some of the attribute construction methods.
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
LogicalResult verify(Operation *op, bool verifyRecursively=true)
Perform (potentially expensive) checks of invariants, used to detect compiler bugs, on this operation and any nested operations.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an instance of a derived operation class as opposed to a raw Operation.
This represents an operation in an abstracted form, suitable for use with the builder APIs.