24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/LogicalResult.h"
32#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
// File-local predicate over an MLIR Type. Only the signature is visible in
// this excerpt; elsewhere in the file its result decides whether a rank-0
// memref is categorized as a scalar.
40static bool isScalarLikeType(
Type type) {
48 if (!varName.empty()) {
49 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
// External model that implements PointerLikeType for a memref type T.
// NOTE(review): method signatures are only partly visible in this excerpt.
55struct MemRefPointerLikeModel
56 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
// The element type is taken directly from the memref type.
58 return cast<T>(pointer).getElementType();
// Type-category query: a MappableType varType answers for itself; otherwise
// the rank of the memref decides the category.
61 mlir::acc::VariableTypeCategory
64 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
65 return mappableTy.getTypeCategory(varPtr);
67 auto memrefTy = cast<T>(pointer);
// A memref without a rank is reported as uncategorized.
68 if (!memrefTy.hasRank()) {
71 return mlir::acc::VariableTypeCategory::uncategorized;
// Rank 0 with a scalar-like element type is a scalar; the uncategorized
// return below presumably covers the remaining rank-0 case (the enclosing
// lines are not visible -- confirm).
74 if (memrefTy.getRank() == 0) {
75 if (isScalarLikeType(memrefTy.getElementType())) {
76 return mlir::acc::VariableTypeCategory::scalar;
80 return mlir::acc::VariableTypeCategory::uncategorized;
// Everything left has positive rank and is reported as an array.
84 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
85 return mlir::acc::VariableTypeCategory::array;
// Emits an allocation for the memref type `pointer` and returns its result.
// Statically shaped memrefs are created with memref::AllocaOp. Otherwise,
// when originalVar has the same ranked memref type, each dynamic dimension is
// queried from originalVar with memref::DimOp and the sizes are passed to
// memref::AllocOp. The created op is handed to attachVarNameAttr with varName.
// NOTE(review): updates to needsFree and the fallback return are not visible
// in this excerpt.
88 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
89 StringRef varName, Type varType, Value originalVar,
90 bool &needsFree)
const {
91 auto memrefTy = cast<MemRefType>(pointer);
95 if (memrefTy.hasStaticShape()) {
97 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
98 attachVarNameAttr(allocaOp, builder, varName);
99 return allocaOp.getResult();
104 if (originalVar && originalVar.
getType() == memrefTy &&
105 memrefTy.hasRank()) {
106 SmallVector<Value> dynamicSizes;
107 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
108 if (memrefTy.isDynamicDim(i)) {
112 memref::DimOp::create(builder, loc, originalVar, indexValue);
113 dynamicSizes.push_back(dimSize);
120 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
121 attachVarNameAttr(allocOp, builder, varName);
122 return allocOp.getResult();
// Emits the deallocation for a memref value. Starting from allocRes (or the
// memref value itself when allocRes is null), the defining ops are followed
// through memref::CastOp and memref::ReinterpretCastOp sources until a
// memref::AllocOp or memref::AllocaOp is found. A memref::AllocOp origin gets
// a memref::DeallocOp on memrefValue.
// NOTE(review): the alloca branch body, the loop exits and the returned
// values are not visible in this excerpt.
129 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 Type varType)
const {
134 Value valueToInspect = allocRes ? allocRes : memrefValue;
137 Value currentValue = valueToInspect;
138 Operation *originalAlloc =
nullptr;
142 while (currentValue) {
145 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
146 originalAlloc = definingOp;
151 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
152 currentValue = castOp.getSource();
157 if (
auto reinterpretCastOp =
158 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
159 currentValue = reinterpretCastOp.getSource();
171 if (isa<memref::AllocaOp>(originalAlloc)) {
175 if (isa<memref::AllocOp>(originalAlloc)) {
177 memref::DeallocOp::create(builder, loc, memrefValue);
// Emits memref::CopyOp from source to destination when both are
// MemRefType-typed values whose element types and shapes are equal.
// NOTE(review): the returned values on both paths are not visible here.
186 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
190 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
191 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
197 if (destMemref && srcMemref &&
198 destMemref.getType().getElementType() ==
199 srcMemref.getType().getElementType() &&
200 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
201 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
// Emits memref::LoadOp on srcPtr when it is a MemRefType-typed value. A rank
// other than zero takes a separate branch whose body is not visible in this
// excerpt (presumably a rejection -- confirm).
208 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
210 Type valueType)
const {
215 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
219 auto memrefTy = memrefValue.
getType();
222 if (memrefTy.getRank() != 0)
225 return memref::LoadOp::create(builder, loc, memrefValue);
// Emits memref::StoreOp of valueToStore into destPtr when it is a
// MemRefType-typed value. A rank other than zero takes a separate branch
// whose body is not visible in this excerpt (presumably a rejection).
228 bool genStore(Type pointer, OpBuilder &builder, Location loc,
234 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
238 auto memrefTy = memrefValue.getType();
241 if (memrefTy.getRank() != 0)
244 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
// External model that implements PointerLikeType for LLVM::LLVMPointerType.
// genLoad emits LLVM::LoadOp of valueType from srcPtr; genStore emits
// LLVM::StoreOp of valueToStore to destPtr.
// NOTE(review): any other members and the genStore return are not visible.
249struct LLVMPointerPointerLikeModel
250 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
251 LLVM::LLVMPointerType> {
254 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
256 Type valueType)
const {
261 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
264 bool genStore(Type pointer, OpBuilder &builder, Location loc,
266 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
// External model that implements AddressOfGlobalOpInterface for
// memref::GetGlobalOp. getSymbol returns the name attribute of the op.
271struct MemrefAddressOfGlobalModel
272 :
public AddressOfGlobalOpInterface::ExternalModel<
273 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
274 SymbolRefAttr getSymbol(Operation *op)
const {
275 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
276 return getGlobalOp.getNameAttr();
// External model that implements GlobalVariableOpInterface; isConstant casts
// the operation to memref::GlobalOp and forwards its constant flag.
// NOTE(review): the body of getInitRegion is not visible in this excerpt.
280struct MemrefGlobalVariableModel
281 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
283 bool isConstant(Operation *op)
const {
284 auto globalOp = cast<memref::GlobalOp>(op);
285 return globalOp.getConstant();
288 Region *getInitRegion(Operation *op)
const {
// Builds the device-type ArrayAttr after adding new device types: existing
// entries are copied first, then one DeviceTypeAttr is appended per entry of
// newDeviceTypes. An empty newDeviceTypes contributes a single
// DeviceType::None entry instead.
298mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
299 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
302 if (existingDeviceTypes)
303 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
305 if (newDeviceTypes.empty())
306 deviceTypes.push_back(
307 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
309 for (DeviceType dt : newDeviceTypes)
310 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
312 return mlir::ArrayAttr::get(context, deviceTypes);
// Overload that also records operands and segment sizes. For every new
// device type (or once with DeviceType::None when newDeviceTypes is empty)
// the arguments are appended to argCollection, their count is appended to
// segments, and a matching DeviceTypeAttr is appended after the existing
// entries. Returns the combined device-type ArrayAttr.
321mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
322 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
327 if (existingDeviceTypes)
328 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
330 if (newDeviceTypes.empty()) {
331 argCollection.
append(arguments);
332 segments.push_back(arguments.size());
333 deviceTypes.push_back(
334 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
337 for (DeviceType dt : newDeviceTypes) {
338 argCollection.
append(arguments);
339 segments.push_back(arguments.size());
340 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
343 return mlir::ArrayAttr::get(context, deviceTypes);
// Convenience overload that forwards to the overload taking a segments
// argument. NOTE(review): how `segments` is declared here is not visible.
347mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
348 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
352 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
353 newDeviceTypes, arguments,
354 argCollection, segments);
// Dialect initialization. Pulls in the generated operation, attribute and
// type lists (the registration calls wrapping these includes are not visible
// in this excerpt) and attaches the external interface models to MemRefType,
// UnrankedMemRefType, LLVM::LLVMPointerType, memref::GetGlobalOp and
// memref::GlobalOp.
362void OpenACCDialect::initialize() {
365#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
368#define GET_ATTRDEF_LIST
369#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
372#define GET_TYPEDEF_LIST
373#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
379 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
381 UnrankedMemRefType::attachInterface<
382 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
383 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
387 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
389 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
417void ParallelOp::getSuccessorRegions(
429void KernelEnvironmentOp::getSuccessorRegions(
441void HostDataOp::getSuccessorRegions(
452 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
456 mlir::acc::DeviceType deviceType) {
460 for (
auto attr : *arrayAttr) {
461 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
462 if (deviceTypeAttr.getValue() == deviceType)
470 std::optional<mlir::ArrayAttr> deviceTypes) {
475 llvm::interleaveComma(*deviceTypes, p,
481 mlir::acc::DeviceType deviceType) {
482 unsigned segmentIdx = 0;
483 for (
auto attr : segments) {
484 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
485 if (deviceTypeAttr.getValue() == deviceType)
486 return std::make_optional(segmentIdx);
496 mlir::acc::DeviceType deviceType) {
498 return range.take_front(0);
499 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
500 int32_t nbOperandsBefore = 0;
501 for (
unsigned i = 0; i < *pos; ++i)
502 nbOperandsBefore += (*segments)[i];
503 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
505 return range.take_front(0);
512 std::optional<mlir::ArrayAttr> hasWaitDevnum,
513 mlir::acc::DeviceType deviceType) {
516 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
517 if (hasWaitDevnum->getValue()[*pos])
528 std::optional<mlir::ArrayAttr> hasWaitDevnum,
529 mlir::acc::DeviceType deviceType) {
534 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
535 if (hasWaitDevnum && *hasWaitDevnum) {
536 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
537 if (boolAttr.getValue())
538 return range.drop_front(1);
544template <
typename Op>
546 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
548 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
553 op.hasAsyncOnly(dtype))
555 "asyncOnly attribute cannot appear with asyncOperand");
560 op.hasWaitOnly(dtype))
561 return op.
emitError(
"wait attribute cannot appear with waitOperands");
566template <
typename Op>
569 return op.
emitError(
"must have var operand");
572 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
573 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
574 return op.
emitError(
"var must be mappable or pointer-like");
577 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
578 op.getVarType() == op.getVar().getType())
579 return op.
emitError(
"varType must capture the element type of var");
584template <
typename Op>
586 if (op.getVar().getType() != op.getAccVar().getType())
587 return op.
emitError(
"input and output types must match");
592template <
typename Op>
594 if (op.getModifiers() != acc::DataClauseModifier::none)
595 return op.
emitError(
"no data clause modifiers are allowed");
599template <
typename Op>
602 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
604 "invalid data clause modifiers: " +
605 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
// Checks the recipe symbol attached to a data-clause op. OpT is the op kind
// and RecipeOpT the expected recipe declaration kind; operandName is used in
// the diagnostics. The visible diagnostics report a missing recipe and a
// symbol reference that does not point to the expected declaration.
// NOTE(review): the conditions guarding each diagnostic are not visible in
// this excerpt.
610template <
typename OpT,
typename RecipeOpT>
611static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
616 !std::is_same_v<OpT, acc::ReductionOp>)
619 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
621 return op->emitOpError() <<
"recipe expected for " << operandName;
626 return op->emitOpError()
627 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
628 << operandName <<
" declaration";
649 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
670 if (failed(parser.
parseType(accVarType)))
680 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
692 mlir::TypeAttr &varTypeAttr) {
693 if (failed(parser.
parseType(varPtrType)))
704 varTypeAttr = mlir::TypeAttr::get(varType);
709 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
710 varTypeAttr = mlir::TypeAttr::get(
711 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
713 varTypeAttr = mlir::TypeAttr::get(varPtrType);
720 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
728 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
729 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
731 if (typeToCheckAgainst != varType) {
739 mlir::SymbolRefAttr &recipeAttr) {
746 mlir::SymbolRefAttr recipeAttr) {
// Verifier for acc::DataBoundsOp: emits an error when neither an extent nor
// an upperbound is present.
753LogicalResult acc::DataBoundsOp::verify() {
754 auto extent = getExtent();
755 auto upperbound = getUpperbound();
756 if (!extent && !upperbound)
757 return emitError(
"expected extent or upperbound.");
// Verifier for acc::PrivateOp. The visible diagnostic reports a data clause
// that does not match the op's intent; the guarding condition and any other
// checks are not visible in this excerpt.
764LogicalResult acc::PrivateOp::verify() {
767 "data clause associated with private operation must match its intent");
// Verifier for acc::FirstprivateOp. Reports a mismatching data clause and
// passes the op with the label "firstprivate" to a helper whose name is not
// visible here (presumably checkRecipe -- confirm).
781LogicalResult acc::FirstprivateOp::verify() {
783 return emitError(
"data clause associated with firstprivate operation must "
790 *
this,
"firstprivate")))
// Verifier for acc::FirstprivateMapInitialOp. The visible diagnostic concerns
// the op's data clause; the guarding condition is not visible here.
798LogicalResult acc::FirstprivateMapInitialOp::verify() {
800 return emitError(
"data clause associated with firstprivate operation must "
// Verifier for acc::ReductionOp. Reports a mismatching data clause and passes
// the op with the label "reduction" to a helper whose name is not visible
// here (presumably checkRecipe -- confirm).
812LogicalResult acc::ReductionOp::verify() {
814 return emitError(
"data clause associated with reduction operation must "
821 *
this,
"reduction")))
// Verifier for acc::DevicePtrOp. The visible diagnostic concerns the op's
// data clause; the guarding condition is not visible here.
829LogicalResult acc::DevicePtrOp::verify() {
831 return emitError(
"data clause associated with deviceptr operation must "
// Verifier for acc::PresentOp. The visible diagnostic reports a data clause
// that does not match the op's intent; the guarding condition is not visible.
845LogicalResult acc::PresentOp::verify() {
848 "data clause associated with present operation must match its intent");
// Verifier for acc::CopyinOp. For ops not marked implicit, the data clause is
// compared against acc_copyin and further clauses that are not visible here;
// a mismatch produces the diagnostic below. A later check passes a modifier
// set containing at least `always` and `capture` (the rest of the expression
// is not visible).
861LogicalResult acc::CopyinOp::verify() {
863 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
868 "data clause associated with copyin operation must match its intent"
869 " or specify original clause this operation was decomposed from");
875 acc::DataClauseModifier::always |
876 acc::DataClauseModifier::capture)))
// True when the data clause is acc_copyin_readonly or the modifiers contain
// the readonly bit.
881bool acc::CopyinOp::isCopyinReadonly() {
882 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
883 acc::bitEnumContainsAny(getModifiers(),
884 acc::DataClauseModifier::readonly);
// Verifier for acc::CreateOp. The visible diagnostic reports a data clause
// that neither matches the op's intent nor names the clause it was decomposed
// from. A later check passes a modifier set containing at least `always` and
// `capture` (the rest of the expression is not visible).
890LogicalResult acc::CreateOp::verify() {
897 "data clause associated with create operation must match its intent"
898 " or specify original clause this operation was decomposed from");
906 acc::DataClauseModifier::always |
907 acc::DataClauseModifier::capture)))
// True when the data clause is acc_create_zero or the modifiers contain the
// zero bit. NOTE(review): original lines between the two visible operands are
// missing from this excerpt, so further alternatives may exist.
912bool acc::CreateOp::isCreateZero() {
914 return getDataClause() == acc::DataClause::acc_create_zero ||
916 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
// Verifier for acc::NoCreateOp. The visible diagnostic concerns the op's data
// clause; the guarding condition is not visible here.
922LogicalResult acc::NoCreateOp::verify() {
924 return emitError(
"data clause associated with no_create operation must "
// Verifier for acc::AttachOp. The visible diagnostic reports a data clause
// that does not match the op's intent; the guarding condition is not visible.
938LogicalResult acc::AttachOp::verify() {
941 "data clause associated with attach operation must match its intent");
// Verifier for acc::DeclareDeviceResidentOp: emits an error unless the data
// clause is acc_declare_device_resident. Any further checks are not visible
// in this excerpt.
955LogicalResult acc::DeclareDeviceResidentOp::verify() {
956 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
957 return emitError(
"data clause associated with device_resident operation "
958 "must match its intent");
// Verifier for acc::DeclareLinkOp. The visible diagnostic reports a data
// clause that does not match the op's intent; the condition is not visible.
972LogicalResult acc::DeclareLinkOp::verify() {
975 "data clause associated with link operation must match its intent");
// Verifier for acc::CopyoutOp. Visible diagnostics: a data clause that
// neither matches the intent nor names the original clause, and a missing
// host or device pointer. A later check passes a modifier set containing at
// least `always` and `capture`. Guarding conditions are not visible here.
988LogicalResult acc::CopyoutOp::verify() {
995 "data clause associated with copyout operation must match its intent"
996 " or specify original clause this operation was decomposed from");
998 return emitError(
"must have both host and device pointers");
1004 acc::DataClauseModifier::always |
1005 acc::DataClauseModifier::capture)))
// True when the data clause is acc_copyout_zero or the modifiers contain the
// zero bit.
1010bool acc::CopyoutOp::isCopyoutZero() {
1011 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1012 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
// Verifier for acc::DeleteOp. The data clause is compared against a list of
// clauses (only acc_declare_device_resident is visible here); visible
// diagnostics cover a mismatching clause and a missing device pointer. A
// later check passes a modifier set containing at least `readonly`, `always`
// and `capture`.
1018LogicalResult acc::DeleteOp::verify() {
1027 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1030 "data clause associated with delete operation must match its intent"
1031 " or specify original clause this operation was decomposed from");
1033 return emitError(
"must have device pointer");
1037 acc::DataClauseModifier::readonly |
1038 acc::DataClauseModifier::always |
1039 acc::DataClauseModifier::capture)))
// Verifier for acc::DetachOp. Visible diagnostics cover a mismatching data
// clause and a missing device pointer; guarding conditions are not visible.
1047LogicalResult acc::DetachOp::verify() {
1052 "data clause associated with detach operation must match its intent"
1053 " or specify original clause this operation was decomposed from");
1055 return emitError(
"must have device pointer");
// Verifier for acc::UpdateHostOp. Visible diagnostics cover a mismatching
// data clause and a missing host or device pointer; guarding conditions are
// not visible in this excerpt.
1064LogicalResult acc::UpdateHostOp::verify() {
1069 "data clause associated with host operation must match its intent"
1070 " or specify original clause this operation was decomposed from");
1072 return emitError(
"must have both host and device pointers");
// Verifier for acc::UpdateDeviceOp. The visible diagnostic reports a
// mismatching data clause; the guarding condition is not visible here.
1085LogicalResult acc::UpdateDeviceOp::verify() {
1089 "data clause associated with device operation must match its intent"
1090 " or specify original clause this operation was decomposed from");
// Verifier for acc::UseDeviceOp. The visible diagnostic reports a
// mismatching data clause; the guarding condition is not visible here.
1103LogicalResult acc::UseDeviceOp::verify() {
1107 "data clause associated with use_device operation must match its intent"
1108 " or specify original clause this operation was decomposed from");
// Verifier for acc::CacheOp. The visible diagnostic reports a mismatching
// data clause; the guarding condition is not visible here.
1121LogicalResult acc::CacheOp::verify() {
1126 "data clause associated with cache operation must match its intent"
1127 " or specify original clause this operation was decomposed from");
// True when the data clause is acc_cache_readonly or the modifiers contain
// the readonly bit.
1137bool acc::CacheOp::isCacheReadonly() {
1138 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1139 acc::bitEnumContainsAny(getModifiers(),
1140 acc::DataClauseModifier::readonly);
1143template <
typename StructureOp>
1145 unsigned nRegions = 1) {
1148 for (
unsigned i = 0; i < nRegions; ++i)
1151 for (
Region *region : regions)
1159 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1166template <
typename OpTy>
1168 using OpRewritePattern<OpTy>::OpRewritePattern;
1170 LogicalResult matchAndRewrite(OpTy op,
1171 PatternRewriter &rewriter)
const override {
1173 Value ifCond = op.getIfCond();
1177 IntegerAttr constAttr;
1180 if (constAttr.getInt())
1181 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1193 assert(region.
hasOneBlock() &&
"expected single-block region");
// Rewrite pattern over an op type OpTy that exposes an if condition. When the
// condition's constant attribute is non-zero, the if-condition operand is
// erased in place.
// NOTE(review): how constAttr is matched and what happens for a zero constant
// are not visible in this excerpt.
1205template <
typename OpTy>
1206struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1207 using OpRewritePattern<OpTy>::OpRewritePattern;
1209 LogicalResult matchAndRewrite(OpTy op,
1210 PatternRewriter &rewriter)
const override {
1212 Value ifCond = op.getIfCond();
1216 IntegerAttr constAttr;
1219 if (constAttr.getInt())
1220 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
// Rewrite pattern on acc::KernelEnvironmentOp. The visible part inspects the
// op's wait information: whether any wait device type differs from
// DeviceType::None, whether any has-devnum flag is set, whether there is more
// than one wait segment, and whether wait operands or the wait-only attribute
// are present.
// NOTE(review): the result of each check and the final rewrite are not
// visible in this excerpt.
1230struct RemoveEmptyKernelEnvironment
1232 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1234 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1235 PatternRewriter &rewriter)
const override {
1236 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1247 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1248 for (
auto attr : deviceTypeAttr) {
1249 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1250 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1257 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1258 for (
auto attr : hasDevnumAttr) {
1259 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1260 if (boolAttr.getValue())
1267 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1268 if (segmentsAttr.size() > 1)
1274 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1301 for (
Value bound : bounds) {
1302 argTypes.push_back(bound.getType());
1303 argLocs.push_back(loc);
1310 Value privatizedValue;
1316 if (isa<MappableType>(varType)) {
1317 auto mappableTy = cast<MappableType>(varType);
1318 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1319 privatizedValue = mappableTy.generatePrivateInit(
1320 builder, loc, typedVar, varName, bounds, {}, needsFree);
1321 if (!privatizedValue)
1324 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1325 auto pointerLikeTy = cast<PointerLikeType>(varType);
1327 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1328 blockArgVar, needsFree);
1329 if (!privatizedValue)
1334 acc::YieldOp::create(builder, loc, privatizedValue);
1349 for (
Value bound : bounds) {
1350 copyArgTypes.push_back(bound.getType());
1351 copyArgLocs.push_back(loc);
1358 bool isMappable = isa<MappableType>(varType);
1359 bool isPointerLike = isa<PointerLikeType>(varType);
1362 if (isMappable && !isPointerLike)
1366 if (isPointerLike) {
1367 auto pointerLikeTy = cast<PointerLikeType>(varType);
1372 if (!pointerLikeTy.genCopy(
1379 acc::TerminatorOp::create(builder, loc);
1393 for (
Value bound : bounds) {
1394 destroyArgTypes.push_back(bound.getType());
1395 destroyArgLocs.push_back(loc);
1399 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1403 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1404 if (isa<MappableType>(varType)) {
1405 auto mappableTy = cast<MappableType>(varType);
1406 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1409 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1410 auto pointerLikeTy = cast<PointerLikeType>(varType);
1411 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1415 acc::TerminatorOp::create(builder, loc);
1426 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1428 if (optional && region.
empty())
1432 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1436 return op->
emitOpError() <<
"expects " << regionName
1439 << regionType <<
" type";
1442 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1443 if (yieldOp.getOperands().size() != 1 ||
1444 yieldOp.getOperands().getTypes()[0] != type)
1445 return op->
emitOpError() <<
"expects " << regionName
1447 "yield a value of the "
1448 << regionType <<
" type";
// Region verifier for acc::PrivateRecipeOp. The visible arguments show two
// region checks labelled "privatization": one named "init" and one for the
// destroy region named "destroy", both against the recipe's type.
// NOTE(review): the callee and remaining arguments are not visible here.
1454LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1456 "privatization",
"init",
getType(),
1460 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1466std::optional<PrivateRecipeOp>
1468 StringRef recipeName,
Type varType,
1471 bool isMappable = isa<MappableType>(varType);
1472 bool isPointerLike = isa<PointerLikeType>(varType);
1475 if (!isMappable && !isPointerLike)
1476 return std::nullopt;
1481 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1484 bool needsFree =
false;
1485 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1486 varName, bounds, needsFree))) {
1488 return std::nullopt;
1495 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1496 Value allocRes = yieldOp.getOperand(0);
1498 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1499 varType, allocRes, bounds))) {
1501 return std::nullopt;
1508std::optional<PrivateRecipeOp>
1510 StringRef recipeName,
1511 FirstprivateRecipeOp firstprivRecipe) {
1514 auto varType = firstprivRecipe.getType();
1515 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1519 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1522 if (!firstprivRecipe.getDestroyRegion().empty()) {
1524 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
// Region verifier for acc::FirstprivateRecipeOp. Visible checks: an init
// region check labelled "privatization"/"init", a non-empty copy region, a
// diagnostic about the copy region's two arguments of the privatization type,
// and a destroy region check that is preceded by an emptiness test.
// NOTE(review): callees and several conditions are not visible here.
1534LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1536 "privatization",
"init",
getType(),
1540 if (getCopyRegion().empty())
1541 return emitOpError() <<
"expects non-empty copy region";
1546 return emitOpError() <<
"expects copy region with two arguments of the "
1547 "privatization type";
1549 if (getDestroyRegion().empty())
1553 "privatization",
"destroy",
1560std::optional<FirstprivateRecipeOp>
1562 StringRef recipeName,
Type varType,
1565 bool isMappable = isa<MappableType>(varType);
1566 bool isPointerLike = isa<PointerLikeType>(varType);
1569 if (!isMappable && !isPointerLike)
1570 return std::nullopt;
1575 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1578 bool needsFree =
false;
1579 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1580 varName, bounds, needsFree))) {
1582 return std::nullopt;
1586 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1589 return std::nullopt;
1596 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1597 Value allocRes = yieldOp.getOperand(0);
1599 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1600 varType, allocRes, bounds))) {
1602 return std::nullopt;
// Region verifier for acc::ReductionRecipeOp. Visible checks: the combiner
// region must be non-empty, a diagnostic about its first two arguments being
// of the reduction type, and every acc YieldOp in the combiner region must
// yield exactly one value of the recipe's type.
// NOTE(review): the argument-type condition itself is not visible here.
1613LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1619 if (getCombinerRegion().empty())
1620 return emitOpError() <<
"expects non-empty combiner region";
1622 Block &reductionBlock = getCombinerRegion().
front();
1626 return emitOpError() <<
"expects combiner region with the first two "
1627 <<
"arguments of the reduction type";
1629 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1630 if (yieldOp.getOperands().size() != 1 ||
1631 yieldOp.getOperands().getTypes()[0] !=
getType())
1632 return emitOpError() <<
"expects combiner region to yield a value "
1633 "of the reduction type";
1644template <
typename Op>
1648 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1649 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1650 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1651 operand.getDefiningOp()))
1653 "expect data entry/exit operation or acc.getdeviceptr "
1658template <
typename OpT,
typename RecipeOpT>
1661 llvm::StringRef operandName) {
1664 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1666 <<
"expected " << operandName <<
" as defining op";
1667 if (!set.insert(operand).second)
1669 << operandName <<
" operand appears more than once";
// Number of data operands: the combined count of reduction, private,
// firstprivate and data-clause operands.
1674unsigned ParallelOp::getNumDataOperands() {
1675 return getReductionOperands().size() + getPrivateOperands().size() +
1676 getFirstprivateOperands().size() + getDataClauseOperands().size();
// Returns the i-th data operand by indexing the op's operand list past the
// wait operands and the operands counted in numOptional (num_gangs,
// num_workers, vector_length, plus one each for a present if/self condition).
// NOTE(review): the declaration and initial value of numOptional are on a
// line not visible in this excerpt.
1679Value ParallelOp::getDataOperand(
unsigned i) {
1681 numOptional += getNumGangs().size();
1682 numOptional += getNumWorkers().size();
1683 numOptional += getVectorLength().size();
1684 numOptional += getIfCond() ? 1 : 0;
1685 numOptional += getSelfCond() ? 1 : 0;
1686 return getOperand(getWaitOperands().size() + numOptional + i);
1689template <
typename Op>
1692 llvm::StringRef keyword) {
1693 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1694 return op.
emitOpError() << keyword <<
" operands count must match "
1695 << keyword <<
" device_type count";
1699template <
typename Op>
1702 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1703 std::size_t numOperandsInSegments = 0;
1704 std::size_t nbOfSegments = 0;
1707 for (
auto segCount : segments.
asArrayRef()) {
1708 if (maxInSegment != 0 && segCount > maxInSegment)
1709 return op.
emitOpError() << keyword <<
" expects a maximum of "
1710 << maxInSegment <<
" values per segment";
1711 numOperandsInSegments += segCount;
1716 if ((numOperandsInSegments != operands.size()) ||
1717 (!deviceTypes && !operands.empty()))
1719 << keyword <<
" operand count does not match count in segments";
1720 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1722 << keyword <<
" segment count does not match device_type count";
// Verifier for acc::ParallelOp. The visible arguments show checks of the
// private, firstprivate and reduction operands against their recipe op kinds,
// checks of num_gangs (with a trailing argument of 3) and wait operands
// against their segments and device-type attributes, and checks involving the
// device-type attributes of num_workers, vector_length and async operands.
// NOTE(review): the callee names are not visible in this excerpt.
1726LogicalResult acc::ParallelOp::verify() {
1728 mlir::acc::PrivateRecipeOp>(
1729 *
this, getPrivateOperands(),
"private")))
1732 mlir::acc::FirstprivateRecipeOp>(
1733 *
this, getFirstprivateOperands(),
"firstprivate")))
1736 mlir::acc::ReductionRecipeOp>(
1737 *
this, getReductionOperands(),
"reduction")))
1741 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1742 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1746 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1747 getWaitOperandsDeviceTypeAttr(),
"wait")))
1751 getNumWorkersDeviceTypeAttr(),
1756 getVectorLengthDeviceTypeAttr(),
1761 getAsyncOperandsDeviceTypeAttr(),
1774 mlir::acc::DeviceType deviceType) {
1777 if (
auto pos =
findSegment(*arrayAttr, deviceType))
// ParallelOp accessors for device-type dependent clauses. Each accessor comes
// as a pair: the parameterless form forwards to the per-device-type form with
// DeviceType::None.
// NOTE(review): most bodies of the per-device-type forms, and some headers of
// the parameterless forms, are not visible in this excerpt.
1782bool acc::ParallelOp::hasAsyncOnly() {
1783 return hasAsyncOnly(mlir::acc::DeviceType::None);
1786bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1791 return getAsyncValue(mlir::acc::DeviceType::None);
1794mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1799mlir::Value acc::ParallelOp::getNumWorkersValue() {
1800 return getNumWorkersValue(mlir::acc::DeviceType::None);
1804acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1809mlir::Value acc::ParallelOp::getVectorLengthValue() {
1810 return getVectorLengthValue(mlir::acc::DeviceType::None);
1814acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1816 getVectorLength(), deviceType);
1820 return getNumGangsValues(mlir::acc::DeviceType::None);
1824ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1826 getNumGangsSegments(), deviceType);
1829bool acc::ParallelOp::hasWaitOnly() {
1830 return hasWaitOnly(mlir::acc::DeviceType::None);
1833bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1838 return getWaitValues(mlir::acc::DeviceType::None);
1842ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1844 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1845 getHasWaitDevnum(), deviceType);
1849 return getWaitDevnum(mlir::acc::DeviceType::None);
1852mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1854 getWaitOperandsSegments(), getHasWaitDevnum(),
1869 odsBuilder, odsState, asyncOperands,
nullptr,
1870 nullptr, waitOperands,
nullptr,
1872 nullptr, numGangs,
nullptr,
1873 nullptr, numWorkers,
1874 nullptr, vectorLength,
1875 nullptr, ifCond, selfCond,
1876 nullptr, reductionOperands, gangPrivateOperands,
1877 gangFirstPrivateOperands, dataClauseOperands,
// Adds newValue to the num_workers operands and replaces the num_workers
// device-type attribute with the array returned by
// addDeviceTypeAffectedOperandHelper for effectiveDeviceTypes.
// NOTE(review): the parameter list is not visible in this excerpt.
1881void acc::ParallelOp::addNumWorkersOperand(
1884 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1885 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1886 getNumWorkersMutable()));
// Adds newValue to the vector_length operands and replaces the vector_length
// device-type attribute with the array returned by
// addDeviceTypeAffectedOperandHelper for effectiveDeviceTypes.
// NOTE(review): the parameter list is not visible in this excerpt.
1888void acc::ParallelOp::addVectorLengthOperand(
1891 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1892 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1893 getVectorLengthMutable()));
// Replaces the async-only attribute with the existing entries extended by
// effectiveDeviceTypes (operand-less helper overload).
// NOTE(review): the parameter list is not visible in this excerpt.
1896void acc::ParallelOp::addAsyncOnly(
1898 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1899 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
// Adds newValue to the async operands and replaces the async device-type
// attribute with the array returned by addDeviceTypeAffectedOperandHelper for
// effectiveDeviceTypes.
// NOTE(review): the parameter list is not visible in this excerpt.
1902void acc::ParallelOp::addAsyncOperand(
1905 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1906 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1907 getAsyncOperandsMutable()));
// Adds newValues to the num_gangs operands. Existing segment sizes are copied
// first so the helper can append to them; the device-type attribute and the
// segments are then both written back to the op.
// NOTE(review): the parameter list and the declaration of `segments` are not
// visible in this excerpt.
1910void acc::ParallelOp::addNumGangsOperands(
1914 if (getNumGangsSegments())
1915 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1917 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1918 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1919 getNumGangsMutable(), segments));
1921 setNumGangsSegments(segments);
// Replaces the wait-only attribute with the existing entries extended by
// effectiveDeviceTypes (operand-less helper overload).
// NOTE(review): the parameter list is not visible in this excerpt.
1923void acc::ParallelOp::addWaitOnly(
1925 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1926 effectiveDeviceTypes));
// Adds newValues to the wait operands. Existing segment sizes are copied
// before the helper appends to them, then the device-type attribute and the
// segments are written back. The has-devnum array is rebuilt from the
// existing entries plus new ones; the std::max(..., 1) argument suggests at
// least one entry is added even when effectiveDeviceTypes is empty.
// NOTE(review): the statement that appends the new has-devnum entries is only
// partly visible -- confirm the value being inserted.
1928void acc::ParallelOp::addWaitOperands(
1933 if (getWaitOperandsSegments())
1934 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1936 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1937 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1938 getWaitOperandsMutable(), segments));
1939 setWaitOperandsSegments(segments);
1942 if (getHasWaitDevnumAttr())
1943 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1946 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1948 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
// Links the private op to its recipe by symbol name, then appends the private
// op's result to this op's private operands.
1951void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1952 mlir::acc::PrivateOp op,
1953 mlir::acc::PrivateRecipeOp recipe) {
1954 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1955 getPrivateOperandsMutable().append(op.getResult());
// Links the firstprivate op to its recipe by symbol name, then appends the
// firstprivate op's result to this op's firstprivate operands.
1958void acc::ParallelOp::addFirstPrivatization(
1959 MLIRContext *context, mlir::acc::FirstprivateOp op,
1960 mlir::acc::FirstprivateRecipeOp recipe) {
1961 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1962 getFirstprivateOperandsMutable().append(op.getResult());
// Links the reduction op to its recipe by symbol name, then appends the
// reduction op's result to this op's reduction operands.
1965void acc::ParallelOp::addReduction(
MLIRContext *context,
1966 mlir::acc::ReductionOp op,
1967 mlir::acc::ReductionRecipeOp recipe) {
1968 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1969 getReductionOperandsMutable().append(op.getResult());
1984 int32_t crtOperandsSize = operands.size();
1987 if (parser.parseOperand(operands.emplace_back()) ||
1988 parser.parseColonType(types.emplace_back()))
1993 seg.push_back(operands.size() - crtOperandsSize);
2003 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2004 parser.
getContext(), mlir::acc::DeviceType::None));
2010 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2017 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2018 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2019 p <<
" [" << attr <<
"]";
2024 std::optional<mlir::ArrayAttr> deviceTypes,
2025 std::optional<mlir::DenseI32ArrayAttr> segments) {
2027 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2029 llvm::interleaveComma(
2030 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2031 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2051 int32_t crtOperandsSize = operands.size();
2055 if (parser.parseOperand(operands.emplace_back()) ||
2056 parser.parseColonType(types.emplace_back()))
2062 seg.push_back(operands.size() - crtOperandsSize);
2072 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2073 parser.
getContext(), mlir::acc::DeviceType::None));
2079 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2088 std::optional<mlir::DenseI32ArrayAttr> segments) {
2090 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2092 llvm::interleaveComma(
2093 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2094 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2107 mlir::ArrayAttr &keywordOnly) {
2111 bool needCommaBeforeOperands =
false;
2115 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2116 parser.
getContext(), mlir::acc::DeviceType::None));
2117 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2124 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2131 needCommaBeforeOperands =
true;
2134 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2141 int32_t crtOperandsSize = operands.size();
2153 if (parser.parseOperand(operands.emplace_back()) ||
2154 parser.parseColonType(types.emplace_back()))
2160 seg.push_back(operands.size() - crtOperandsSize);
2170 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2171 parser.
getContext(), mlir::acc::DeviceType::None));
2178 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2179 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2181 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2189 if (attrs->size() != 1)
2191 if (
auto deviceTypeAttr =
2192 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2193 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2199 std::optional<mlir::ArrayAttr> deviceTypes,
2200 std::optional<mlir::DenseI32ArrayAttr> segments,
2201 std::optional<mlir::ArrayAttr> hasDevNum,
2202 std::optional<mlir::ArrayAttr> keywordOnly) {
2215 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2217 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2218 if (boolAttr && boolAttr.getValue())
2220 llvm::interleaveComma(
2221 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2222 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2239 if (parser.parseOperand(operands.emplace_back()) ||
2240 parser.parseColonType(types.emplace_back()))
2242 if (succeeded(parser.parseOptionalLSquare())) {
2243 if (parser.parseAttribute(attributes.emplace_back()) ||
2244 parser.parseRSquare())
2247 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2248 parser.getContext(), mlir::acc::DeviceType::None));
2255 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2262 std::optional<mlir::ArrayAttr> deviceTypes) {
2265 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2266 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2275 mlir::ArrayAttr &keywordOnlyDeviceType) {
2278 bool needCommaBeforeOperands =
false;
2282 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2283 parser.
getContext(), mlir::acc::DeviceType::None));
2284 keywordOnlyDeviceType =
2285 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2293 if (parser.parseAttribute(
2294 keywordOnlyDeviceTypeAttributes.emplace_back()))
2301 needCommaBeforeOperands =
true;
2304 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2309 if (parser.parseOperand(operands.emplace_back()) ||
2310 parser.parseColonType(types.emplace_back()))
2312 if (succeeded(parser.parseOptionalLSquare())) {
2313 if (parser.parseAttribute(attributes.emplace_back()) ||
2314 parser.parseRSquare())
2317 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2318 parser.getContext(), mlir::acc::DeviceType::None));
2324 if (
failed(parser.parseRParen()))
2329 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2336 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2338 if (operands.begin() == operands.end() &&
2354 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2355 mlir::Type &operandType, mlir::UnitAttr &attr) {
2358 attr = mlir::UnitAttr::get(parser.
getContext());
2368 if (failed(parser.
parseType(operandType)))
2378 std::optional<mlir::Value> operand,
2380 mlir::UnitAttr attr) {
2397 attr = mlir::UnitAttr::get(parser.
getContext());
2402 if (parser.parseOperand(operands.emplace_back()))
2410 if (parser.parseType(types.emplace_back()))
2425 mlir::UnitAttr attr) {
2430 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2432 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2438 mlir::acc::CombinedConstructsTypeAttr &attr) {
2440 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2441 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2443 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2444 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2446 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2447 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2450 "expected compute construct name");
2458 mlir::acc::CombinedConstructsTypeAttr attr) {
2460 switch (attr.getValue()) {
2461 case mlir::acc::CombinedConstructsType::KernelsLoop:
2464 case mlir::acc::CombinedConstructsType::ParallelLoop:
2467 case mlir::acc::CombinedConstructsType::SerialLoop:
/// Returns the combined count of reduction, private, firstprivate and data
/// clause operands.
2478unsigned SerialOp::getNumDataOperands() {
2479 return getReductionOperands().size() + getPrivateOperands().size() +
2480 getFirstprivateOperands().size() + getDataClauseOperands().size();
/// Returns the operand found at `i` past the operands counted here: the wait
/// operands plus `numOptional`, which grows by one for each of the if and self
/// conditions that is present.
/// NOTE(review): the declaration/initial value of `numOptional` is not visible
/// in this excerpt.
2483Value SerialOp::getDataOperand(
unsigned i) {
2485 numOptional += getIfCond() ? 1 : 0;
2486 numOptional += getSelfCond() ? 1 : 0;
2487 return getOperand(getWaitOperands().size() + numOptional + i);
/// Convenience overload: forwards to hasAsyncOnly(DeviceType) with
/// DeviceType::None.
2490bool acc::SerialOp::hasAsyncOnly() {
2491 return hasAsyncOnly(mlir::acc::DeviceType::None);
2494bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2499 return getAsyncValue(mlir::acc::DeviceType::None);
2502mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to hasWaitOnly(DeviceType) with
/// DeviceType::None.
2507bool acc::SerialOp::hasWaitOnly() {
2508 return hasWaitOnly(mlir::acc::DeviceType::None);
2511bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2516 return getWaitValues(mlir::acc::DeviceType::None);
2520SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2522 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2523 getHasWaitDevnum(), deviceType);
2527 return getWaitDevnum(mlir::acc::DeviceType::None);
2530mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2532 getWaitOperandsSegments(), getHasWaitDevnum(),
/// Verifier for acc::SerialOp.
/// NOTE(review): the names of the checker calls and their failure returns are
/// cut from this excerpt. The visible arguments show checks over the private,
/// firstprivate and reduction operand lists (paired with PrivateRecipeOp,
/// FirstprivateRecipeOp and ReductionRecipeOp respectively), over the wait
/// operands together with their segments and device-type attributes, and one
/// that takes the async operands device-type attribute.
2536LogicalResult acc::SerialOp::verify() {
2538 mlir::acc::PrivateRecipeOp>(
2539 *
this, getPrivateOperands(),
"private")))
2542 mlir::acc::FirstprivateRecipeOp>(
2543 *
this, getFirstprivateOperands(),
"firstprivate")))
2546 mlir::acc::ReductionRecipeOp>(
2547 *
this, getReductionOperands(),
"reduction")))
2551 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2552 getWaitOperandsDeviceTypeAttr(),
"wait")))
2556 getAsyncOperandsDeviceTypeAttr(),
/// Replaces the asyncOnly attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
2566void acc::SerialOp::addAsyncOnly(
2568 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2569 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
/// Replaces the async operands device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable async operand range. The parameter list is cut
/// from this excerpt.
2572void acc::SerialOp::addAsyncOperand(
2575 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2576 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2577 getAsyncOperandsMutable()));
/// Replaces the waitOnly attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
2580void acc::SerialOp::addWaitOnly(
2582 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2583 effectiveDeviceTypes));
/// Adds wait operands for `effectiveDeviceTypes` and refreshes the three
/// pieces of bookkeeping that go with them: the wait operands device-type
/// attribute, the wait operand segments, and the hasWaitDevnum array.
/// NOTE(review): the parameter list, the local declarations and the call that
/// extends `hasDevnums` are cut from this excerpt.
2585void acc::SerialOp::addWaitOperands(
// Start from the segment sizes that already exist, if any.
2590 if (getWaitOperandsSegments())
2591 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2593 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2594 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2595 getWaitOperandsMutable(), segments));
2596 setWaitOperandsSegments(segments);
// Carry over the existing hasWaitDevnum entries, if any.
2599 if (getHasWaitDevnumAttr())
2600 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2603 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2605 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's private operands.
2608void acc::SerialOp::addPrivatization(
MLIRContext *context,
2609 mlir::acc::PrivateOp op,
2610 mlir::acc::PrivateRecipeOp recipe) {
2611 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2612 getPrivateOperandsMutable().append(op.getResult());
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's firstprivate operands.
2615void acc::SerialOp::addFirstPrivatization(
2616 MLIRContext *context, mlir::acc::FirstprivateOp op,
2617 mlir::acc::FirstprivateRecipeOp recipe) {
2618 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2619 getFirstprivateOperandsMutable().append(op.getResult());
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's reduction operands.
2622void acc::SerialOp::addReduction(
MLIRContext *context,
2623 mlir::acc::ReductionOp op,
2624 mlir::acc::ReductionRecipeOp recipe) {
2625 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2626 getReductionOperandsMutable().append(op.getResult());
/// Returns the number of data clause operands.
2633unsigned KernelsOp::getNumDataOperands() {
2634 return getDataClauseOperands().size();
/// Returns the operand found at `i` past the operands counted in
/// `numOptional`: the wait, num_gangs, num_workers and vector_length operands,
/// plus one for each of the if and self conditions that is present.
/// NOTE(review): the declaration/initial value of `numOptional` is not visible
/// in this excerpt.
2637Value KernelsOp::getDataOperand(
unsigned i) {
2639 numOptional += getWaitOperands().size();
2640 numOptional += getNumGangs().size();
2641 numOptional += getNumWorkers().size();
2642 numOptional += getVectorLength().size();
2643 numOptional += getIfCond() ? 1 : 0;
2644 numOptional += getSelfCond() ? 1 : 0;
2645 return getOperand(numOptional + i);
/// Convenience overload: forwards to hasAsyncOnly(DeviceType) with
/// DeviceType::None.
2648bool acc::KernelsOp::hasAsyncOnly() {
2649 return hasAsyncOnly(mlir::acc::DeviceType::None);
2652bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2657 return getAsyncValue(mlir::acc::DeviceType::None);
2660mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2666 return getNumWorkersValue(mlir::acc::DeviceType::None);
2670acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to getVectorLengthValue(DeviceType) with
/// DeviceType::None.
2675mlir::Value acc::KernelsOp::getVectorLengthValue() {
2676 return getVectorLengthValue(mlir::acc::DeviceType::None);
2680acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2682 getVectorLength(), deviceType);
2686 return getNumGangsValues(mlir::acc::DeviceType::None);
2690KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2692 getNumGangsSegments(), deviceType);
/// Convenience overload: forwards to hasWaitOnly(DeviceType) with
/// DeviceType::None.
2695bool acc::KernelsOp::hasWaitOnly() {
2696 return hasWaitOnly(mlir::acc::DeviceType::None);
2699bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2704 return getWaitValues(mlir::acc::DeviceType::None);
2708KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2710 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2711 getHasWaitDevnum(), deviceType);
2715 return getWaitDevnum(mlir::acc::DeviceType::None);
2718mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2720 getWaitOperandsSegments(), getHasWaitDevnum(),
/// Verifier for acc::KernelsOp.
/// NOTE(review): the names of the checker calls and their failure returns are
/// cut from this excerpt. The visible arguments show a check over the
/// num_gangs operands with their segments and device-type attributes (with a
/// trailing argument of 3), a check over the wait operands with their segments
/// and device-type attributes, and checks that take the num_workers,
/// vector_length and async operands device-type attributes.
2724LogicalResult acc::KernelsOp::verify() {
2726 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2727 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2731 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2732 getWaitOperandsDeviceTypeAttr(),
"wait")))
2736 getNumWorkersDeviceTypeAttr(),
2741 getVectorLengthDeviceTypeAttr(),
2746 getAsyncOperandsDeviceTypeAttr(),
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's private operands.
2756void acc::KernelsOp::addPrivatization(
MLIRContext *context,
2757 mlir::acc::PrivateOp op,
2758 mlir::acc::PrivateRecipeOp recipe) {
2759 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2760 getPrivateOperandsMutable().append(op.getResult());
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's firstprivate operands.
2763void acc::KernelsOp::addFirstPrivatization(
2764 MLIRContext *context, mlir::acc::FirstprivateOp op,
2765 mlir::acc::FirstprivateRecipeOp recipe) {
2766 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2767 getFirstprivateOperandsMutable().append(op.getResult());
/// Points `op` at `recipe` by setting its recipe attribute to a symbol
/// reference built from the recipe's symbol name, then appends the result of
/// `op` to this op's reduction operands.
2770void acc::KernelsOp::addReduction(
MLIRContext *context,
2771 mlir::acc::ReductionOp op,
2772 mlir::acc::ReductionRecipeOp recipe) {
2773 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2774 getReductionOperandsMutable().append(op.getResult());
/// Replaces the num_workers device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable num_workers operand range. The parameter list is
/// cut from this excerpt.
2777void acc::KernelsOp::addNumWorkersOperand(
2780 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2781 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2782 getNumWorkersMutable()));
/// Replaces the vector_length device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable vector_length operand range. The parameter list
/// is cut from this excerpt.
2785void acc::KernelsOp::addVectorLengthOperand(
2788 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2789 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2790 getVectorLengthMutable()));
/// Replaces the asyncOnly attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
2792void acc::KernelsOp::addAsyncOnly(
2794 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2795 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
/// Replaces the async operands device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable async operand range. The parameter list is cut
/// from this excerpt.
2798void acc::KernelsOp::addAsyncOperand(
2801 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2802 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2803 getAsyncOperandsMutable()));
/// Adds num_gangs operands for `effectiveDeviceTypes` and refreshes the
/// num_gangs device-type attribute and segment sizes to match.
/// NOTE(review): the parameter list and the declaration of `segments` are cut
/// from this excerpt.
2806void acc::KernelsOp::addNumGangsOperands(
// Start from the segment sizes that already exist, if any.
2810 if (getNumGangsSegmentsAttr())
2811 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2813 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2814 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2815 getNumGangsMutable(), segments));
2817 setNumGangsSegments(segments);
/// Replaces the waitOnly attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
2820void acc::KernelsOp::addWaitOnly(
2822 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2823 effectiveDeviceTypes));
/// Adds wait operands for `effectiveDeviceTypes` and refreshes the three
/// pieces of bookkeeping that go with them: the wait operands device-type
/// attribute, the wait operand segments, and the hasWaitDevnum array.
/// NOTE(review): the parameter list, the local declarations and the call that
/// extends `hasDevnums` are cut from this excerpt.
2825void acc::KernelsOp::addWaitOperands(
// Start from the segment sizes that already exist, if any.
2830 if (getWaitOperandsSegments())
2831 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2833 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2834 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2835 getWaitOperandsMutable(), segments));
2836 setWaitOperandsSegments(segments);
// Carry over the existing hasWaitDevnum entries, if any.
2839 if (getHasWaitDevnumAttr())
2840 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2843 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2845 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
/// Verifier for acc::HostDataOp: emits an error when there are no data clause
/// operands, and emits an error for any data clause operand whose defining op
/// is not an acc::UseDeviceOp.
/// NOTE(review): `operand.getDefiningOp()` is handed to `isa` without a
/// block-argument/null guard (DataOp::verify checks `isa<BlockArgument>`
/// first); confirm operands here can never be block arguments, otherwise the
/// `isa` would be applied to a null op.
2852LogicalResult acc::HostDataOp::verify() {
2853 if (getDataClauseOperands().empty())
2854 return emitError(
"at least one operand must appear on the host_data "
2857 for (
mlir::Value operand : getDataClauseOperands())
2858 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2859 return emitError(
"expect data entry operation as defining op");
2865 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
/// Registers the RemoveEmptyKernelEnvironment pattern with `results`. The
/// parameter list is cut from this excerpt.
2872void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2874 results.
add<RemoveEmptyKernelEnvironment>(context);
2886 bool &needCommaBetweenValues,
bool &newValue) {
2893 attributes.push_back(gangArgType);
2894 needCommaBetweenValues =
true;
2905 mlir::ArrayAttr &gangOnlyDeviceType) {
2910 bool needCommaBetweenValues =
false;
2911 bool needCommaBeforeOperands =
false;
2915 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2916 parser.
getContext(), mlir::acc::DeviceType::None));
2917 gangOnlyDeviceType =
2918 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
2926 if (parser.parseAttribute(
2927 gangOnlyDeviceTypeAttributes.emplace_back()))
2934 needCommaBeforeOperands =
true;
2937 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2938 mlir::acc::GangArgType::Num);
2939 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2940 mlir::acc::GangArgType::Dim);
2941 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2942 parser.
getContext(), mlir::acc::GangArgType::Static);
2945 if (needCommaBeforeOperands) {
2946 needCommaBeforeOperands =
false;
2953 int32_t crtOperandsSize = gangOperands.size();
2955 bool newValue =
false;
2956 bool needValue =
false;
2957 if (needCommaBetweenValues) {
2965 gangOperands, gangOperandsType,
2966 gangArgTypeAttributes, argNum,
2967 needCommaBetweenValues, newValue)))
2970 gangOperands, gangOperandsType,
2971 gangArgTypeAttributes, argDim,
2972 needCommaBetweenValues, newValue)))
2974 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2975 gangOperands, gangOperandsType,
2976 gangArgTypeAttributes, argStatic,
2977 needCommaBetweenValues, newValue)))
2980 if (!newValue && needValue) {
2982 "new value expected after comma");
2990 if (gangOperands.empty())
2993 "expect at least one of num, dim or static values");
2999 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3003 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3004 parser.
getContext(), mlir::acc::DeviceType::None));
3007 seg.push_back(gangOperands.size() - crtOperandsSize);
3015 gangArgTypeAttributes.end());
3016 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3017 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3020 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3021 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3029 std::optional<mlir::ArrayAttr> gangArgTypes,
3030 std::optional<mlir::ArrayAttr> deviceTypes,
3031 std::optional<mlir::DenseI32ArrayAttr> segments,
3032 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3034 if (operands.begin() == operands.end() &&
3049 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3051 llvm::interleaveComma(
3052 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3053 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3054 (*gangArgTypes)[opIdx]);
3055 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3056 p << LoopOp::getGangNumKeyword();
3057 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3058 p << LoopOp::getGangDimKeyword();
3059 else if (gangArgTypeAttr.getValue() ==
3060 mlir::acc::GangArgType::Static)
3061 p << LoopOp::getGangStaticKeyword();
3062 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3073 std::optional<mlir::ArrayAttr> segments,
3074 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3077 for (
auto attr : *segments) {
3078 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3079 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3087static std::optional<mlir::acc::DeviceType>
3089 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3091 return std::nullopt;
3092 for (
auto attr : deviceTypes) {
3093 auto deviceTypeAttr =
3094 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3095 if (!deviceTypeAttr)
3096 return mlir::acc::DeviceType::None;
3097 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3098 return deviceTypeAttr.getValue();
3100 return std::nullopt;
/// Verifier for acc::LoopOp. Going by the visible checks and diagnostics, it
/// rejects:
///  - upperbound/step/lowerbound/inclusiveUpperbound lists of differing sizes;
///  - a collapse attribute without a collapse device_type attribute, differing
///    counts between the two, or a duplicate collapse device type;
///  - gang operands without a gangOperandsArgType attribute of matching count;
///  - duplicate device types in the gang, worker, workerNumOperandsDeviceType,
///    vector and vectorOperandsDeviceType attributes;
///  - more than one of auto/independent/seq for a device type, or none of them
///    for DeviceType::None;
///  - gang, worker or vector information on a device type that is marked seq;
///  - an unexpected combined-constructs value and an empty region;
///  - for container-like loops: sibling loops, or fewer nested loop-like ops
///    than the largest collapse value.
/// NOTE(review): many lines (checker call names, failure returns, closing
/// braces) are cut from this excerpt, so the list above may be incomplete.
/// NOTE(review): the `dyn_cast` results in the seq loop and the collapse loop
/// are dereferenced without a null check; presumably the element types are
/// guaranteed elsewhere - confirm.
3103LogicalResult acc::LoopOp::verify() {
// Bound and step operand lists must agree in size.
3104 if (getUpperbound().size() != getStep().size())
3105 return emitError() <<
"number of upperbounds expected to be the same as "
3108 if (getUpperbound().size() != getLowerbound().size())
3109 return emitError() <<
"number of upperbounds expected to be the same as "
3110 "number of lowerbounds";
3112 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3113 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3114 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3115 <<
" as upperbound size";
// Collapse values and their device_type list must come together and match in
// length.
3118 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3119 return emitOpError() <<
"collapse device_type attr must be define when"
3120 <<
" collapse attr is present";
3122 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3123 getCollapseAttr().getValue().size() !=
3124 getCollapseDeviceTypeAttr().getValue().size())
3125 return emitOpError() <<
"collapse attribute count must match collapse"
3126 <<
" device_type count";
3127 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3129 << acc::stringifyDeviceType(*duplicateDeviceType)
3130 <<
"` found in collapseDeviceType attribute";
// Gang operands need a per-operand argument-type attribute of the same count.
3133 if (!getGangOperands().empty()) {
3134 if (!getGangOperandsArgType())
3135 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3136 <<
" when gang operands are present";
3138 if (getGangOperands().size() !=
3139 getGangOperandsArgTypeAttr().getValue().size())
3140 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3141 <<
" gangOperands count";
3143 if (getGangAttr()) {
3146 << acc::stringifyDeviceType(*duplicateDeviceType)
3147 <<
"` found in gang attribute";
3151 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3152 getGangOperandsDeviceTypeAttr(),
"gang")))
3158 << acc::stringifyDeviceType(*duplicateDeviceType)
3159 <<
"` found in worker attribute";
3160 if (
auto duplicateDeviceType =
3163 << acc::stringifyDeviceType(*duplicateDeviceType)
3164 <<
"` found in workerNumOperandsDeviceType attribute";
3166 getWorkerNumOperandsDeviceTypeAttr(),
3173 << acc::stringifyDeviceType(*duplicateDeviceType)
3174 <<
"` found in vector attribute";
3175 if (
auto duplicateDeviceType =
3178 << acc::stringifyDeviceType(*duplicateDeviceType)
3179 <<
"` found in vectorOperandsDeviceType attribute";
3181 getVectorOperandsDeviceTypeAttr(),
3186 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3187 getTileOperandsDeviceTypeAttr(),
"tile")))
3191 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3195 return emitError() <<
"only one of auto, independent, seq can be present "
3201 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3202 return attr.getValue() == mlir::acc::DeviceType::None;
3204 bool hasDefaultSeq =
3206 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3209 bool hasDefaultIndependent =
3210 getIndependentAttr()
3212 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3215 bool hasDefaultAuto =
3217 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3220 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3222 <<
"at least one of auto, independent, seq must be present";
3227 for (
auto attr : getSeqAttr()) {
3228 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3229 if (hasVector(deviceTypeAttr.getValue()) ||
3230 getVectorValue(deviceTypeAttr.getValue()) ||
3231 hasWorker(deviceTypeAttr.getValue()) ||
3232 getWorkerValue(deviceTypeAttr.getValue()) ||
3233 hasGang(deviceTypeAttr.getValue()) ||
3234 getGangValue(mlir::acc::GangArgType::Num,
3235 deviceTypeAttr.getValue()) ||
3236 getGangValue(mlir::acc::GangArgType::Dim,
3237 deviceTypeAttr.getValue()) ||
3238 getGangValue(mlir::acc::GangArgType::Static,
3239 deviceTypeAttr.getValue()))
3240 return emitError() <<
"gang, worker or vector cannot appear with seq";
3245 mlir::acc::PrivateRecipeOp>(
3246 *
this, getPrivateOperands(),
"private")))
3250 mlir::acc::FirstprivateRecipeOp>(
3251 *
this, getFirstprivateOperands(),
"firstprivate")))
3255 mlir::acc::ReductionRecipeOp>(
3256 *
this, getReductionOperands(),
"reduction")))
3259 if (getCombined().has_value() &&
3260 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3261 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3262 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3263 return emitError(
"unexpected combined constructs attribute");
3267 if (getRegion().empty())
3268 return emitError(
"expected non-empty body.");
3270 if (getUnstructured()) {
3271 if (!isContainerLike())
3273 "unstructured acc.loop must not have induction variables");
3274 }
else if (isContainerLike()) {
3278 uint64_t collapseCount = getCollapseValue().value_or(1);
3279 if (getCollapseAttr()) {
3280 for (
auto collapseEntry : getCollapseAttr()) {
3281 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3282 if (intAttr.getValue().getZExtValue() > collapseCount)
3283 collapseCount = intAttr.getValue().getZExtValue();
3291 bool foundSibling =
false;
3293 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3295 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3297 foundSibling =
true;
3302 expectedParent = op;
3305 if (collapseCount == 0)
3311 return emitError(
"found sibling loops inside container-like acc.loop");
3312 if (collapseCount != 0)
3313 return emitError(
"failed to find enough loop-like operations inside "
3314 "container-like acc.loop");
/// Returns the combined count of reduction, private and firstprivate
/// operands.
3320unsigned LoopOp::getNumDataOperands() {
3321 return getReductionOperands().size() + getPrivateOperands().size() +
3322 getFirstprivateOperands().size();
/// Returns the operand found at `i` past the operands counted in
/// `numOptional`: the lowerbound, upperbound, step, gang, vector, worker-num,
/// tile and cache operands.
3325Value LoopOp::getDataOperand(
unsigned i) {
3326 unsigned numOptional =
3327 getLowerbound().size() + getUpperbound().size() + getStep().size();
3328 numOptional += getGangOperands().size();
3329 numOptional += getVectorOperands().size();
3330 numOptional += getWorkerNumOperands().size();
3331 numOptional += getTileOperands().size();
3332 numOptional += getCacheOperands().size();
3333 return getOperand(numOptional + i);
/// Convenience overload: forwards to hasAuto(DeviceType) with
/// DeviceType::None.
3336bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3338bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to hasIndependent(DeviceType) with
/// DeviceType::None.
3342bool LoopOp::hasIndependent() {
3343 return hasIndependent(mlir::acc::DeviceType::None);
3346bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to hasSeq(DeviceType) with
/// DeviceType::None.
3350bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3352bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3357 return getVectorValue(mlir::acc::DeviceType::None);
3360mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3362 getVectorOperands(), deviceType);
/// Convenience overload: forwards to hasVector(DeviceType) with
/// DeviceType::None.
3365bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3367bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3372 return getWorkerValue(mlir::acc::DeviceType::None);
3375mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3377 getWorkerNumOperands(), deviceType);
/// Convenience overload: forwards to hasWorker(DeviceType) with
/// DeviceType::None.
3380bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3382bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3387 return getTileValues(mlir::acc::DeviceType::None);
3391LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3393 getTileOperandsSegments(), deviceType);
/// Convenience overload: forwards to getCollapseValue(DeviceType) with
/// DeviceType::None.
3396std::optional<int64_t> LoopOp::getCollapseValue() {
3397 return getCollapseValue(mlir::acc::DeviceType::None);
/// Returns the collapse value recorded for `deviceType`, or std::nullopt when
/// there is no collapse attribute or findSegment yields no position for
/// `deviceType` in the collapse device_type attribute. The value is read from
/// the collapse attribute entry at the position findSegment returns.
/// NOTE(review): the declaration of `intAttr` is partially cut from this
/// excerpt; its `dyn_cast` result is dereferenced without a null check.
3400std::optional<int64_t>
3401LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3402 if (!getCollapseAttr())
3403 return std::nullopt;
3404 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3406 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3407 return intAttr.getValue().getZExtValue();
3409 return std::nullopt;
/// Convenience overload: forwards to getGangValue(GangArgType, DeviceType)
/// with DeviceType::None.
3412mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3413 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
/// Looks up a gang operand of kind `gangArgType` for `deviceType`. It finds
/// the segment for `deviceType` via findSegment, narrows the gang operands to
/// that segment, and compares each operand's GangArgTypeAttr against
/// `gangArgType`.
/// NOTE(review): the return statements and the start of the `values`
/// expression are cut from this excerpt, so the exact results on match and on
/// miss are not visible here.
3416mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3417 mlir::acc::DeviceType deviceType) {
3418 if (getGangOperands().empty())
3420 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
// Count the operands belonging to the segments before the matching one.
3421 int32_t nbOperandsBefore = 0;
3422 for (
unsigned i = 0; i < *pos; ++i)
3423 nbOperandsBefore += (*getGangOperandsSegments())[i];
3426 .drop_front(nbOperandsBefore)
3427 .take_front((*getGangOperandsSegments())[*pos]);
// Argument-type entries are indexed over all gang operands, so start at the
// first operand of the matching segment.
3429 int32_t argTypeIdx = nbOperandsBefore;
3430 for (
auto value : values) {
3431 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3432 (*getGangOperandsArgType())[argTypeIdx]);
3433 if (gangArgTypeAttr.getValue() == gangArgType)
/// Convenience overload: forwards to hasGang(DeviceType) with
/// DeviceType::None.
3441bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3443bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3448 return {&getRegion()};
3492 if (!regionArgs.empty()) {
3493 p << acc::LoopOp::getControlKeyword() <<
"(";
3494 llvm::interleaveComma(regionArgs, p,
3496 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3497 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3498 <<
" : " << stepType <<
") ";
3505 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3506 effectiveDeviceTypes));
/// Replaces the independent attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
3509void acc::LoopOp::addIndependent(
3511 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3512 context, getIndependentAttr(), effectiveDeviceTypes));
3517 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3518 effectiveDeviceTypes));
/// Appends the collapse `value` for each device type in
/// `effectiveDeviceTypes` (or once for DeviceType::None when that list is
/// empty) to copies of the existing collapse and collapse device_type
/// entries, then stores both arrays back on the op.
/// Asserts that the collapse and collapse device_type attributes are either
/// both present or both absent, and that `value` is 64 bits wide.
/// NOTE(review): part of the parameter list and the declarations of
/// `newValues`/`newDeviceTypes` are cut from this excerpt.
3521void acc::LoopOp::setCollapseForDeviceTypes(
3523 llvm::APInt value) {
3527 assert((getCollapseAttr() ==
nullptr) ==
3528 (getCollapseDeviceTypeAttr() ==
nullptr));
3529 assert(value.getBitWidth() == 64);
// Preserve the entries that are already recorded.
3531 if (getCollapseAttr()) {
3532 for (
const auto &existing :
3533 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3534 newValues.push_back(std::get<0>(existing));
3535 newDeviceTypes.push_back(std::get<1>(existing));
3539 if (effectiveDeviceTypes.empty()) {
3542 newValues.push_back(
3543 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3544 newDeviceTypes.push_back(
3545 acc::DeviceTypeAttr::get(context, DeviceType::None));
3547 for (DeviceType dt : effectiveDeviceTypes) {
3548 newValues.push_back(
3549 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3550 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3554 setCollapseAttr(ArrayAttr::get(context, newValues));
3555 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
/// Adds tile operands (`values`) for `effectiveDeviceTypes` and refreshes the
/// tile operands device-type attribute and segment sizes to match.
/// NOTE(review): the parameter list and the declaration of `segments` are cut
/// from this excerpt.
3558void acc::LoopOp::setTileForDeviceTypes(
// Start from the segment sizes that already exist, if any.
3562 if (getTileOperandsSegments())
3563 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3565 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3566 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3567 getTileOperandsMutable(), segments));
3569 setTileOperandsSegments(segments);
/// Replaces the vector operands device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable vector operand range. The parameter list is cut
/// from this excerpt.
3572void acc::LoopOp::addVectorOperand(
3575 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3576 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3577 newValue, getVectorOperandsMutable()));
/// Replaces the vector attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
3580void acc::LoopOp::addEmptyVector(
3582 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3583 effectiveDeviceTypes));
/// Replaces the worker-num operands device-type attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces; the helper is also handed
/// `newValue` and the mutable worker-num operand range. The parameter list is
/// cut from this excerpt.
3586void acc::LoopOp::addWorkerNumOperand(
3589 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3590 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3591 newValue, getWorkerNumOperandsMutable()));
/// Replaces the worker attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
3594void acc::LoopOp::addEmptyWorker(
3596 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3597 effectiveDeviceTypes));
/// Replaces the gang attribute with the value that
/// addDeviceTypeAffectedOperandHelper produces from the current attribute and
/// `effectiveDeviceTypes`. The parameter list is cut from this excerpt.
3600void acc::LoopOp::addEmptyGang(
3602 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3603 effectiveDeviceTypes));
/// Tests the seq, independent and auto attributes in turn for an entry whose
/// device type equals `dt`.
/// NOTE(review): the return statements are cut from this excerpt; presumably
/// a match in any of the three yields true - confirm.
3606bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3607 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3608 return attr.getValue() == dt;
3610 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3611 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3614 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3616 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3618 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
/// Returns true when any of the no-device-type queries reports vector, worker
/// or gang information: hasVector/getVectorValue, hasWorker/getWorkerValue,
/// hasGang, or a gang value of kind Num, Dim or Static.
3624bool acc::LoopOp::hasDefaultGangWorkerVector() {
3625 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3626 hasGang() || getGangValue(GangArgType::Num) ||
3627 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
/// Picks the loop parallelism mode for `deviceType`. seq, auto and
/// independent are tested for `deviceType` in that order and the first match
/// wins. Two further fallback returns (loop_seq, then loop_auto) follow, and
/// the last resort asserts hasIndependent() before returning
/// loop_independent.
/// NOTE(review): the return type line and the conditions guarding the two
/// fallback returns are cut from this excerpt.
3631acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3632 if (hasSeq(deviceType))
3633 return LoopParMode::loop_seq;
3634 if (hasAuto(deviceType))
3635 return LoopParMode::loop_auto;
3636 if (hasIndependent(deviceType))
3637 return LoopParMode::loop_independent;
3639 return LoopParMode::loop_seq;
3641 return LoopParMode::loop_auto;
3642 assert(hasIndependent() &&
3643 "loop must have default auto, seq, or independent");
3644 return LoopParMode::loop_independent;
/// Adds gang operands (`values`) for `effectiveDeviceTypes`, refreshes the
/// gang operands device-type attribute and segment sizes, and extends the
/// gang argument-type attribute: `argTypes` is appended once per segment that
/// was added, after the entries that already exist.
/// NOTE(review): the parameter list and several local declarations are cut
/// from this excerpt.
3647void acc::LoopOp::addGangOperands(
3652 getGangOperandsSegments())
3653 llvm::copy(*existingSegments, std::back_inserter(segments));
// Remember the segment count so the number of new segments can be derived.
3655 unsigned beforeCount = segments.size();
3657 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3658 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3659 getGangOperandsMutable(), segments));
3661 setGangOperandsSegments(segments);
3668 unsigned numAdded = segments.size() - beforeCount;
3672 if (getGangOperandsArgTypeAttr())
3673 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3675 for (
auto i : llvm::index_range(0u, numAdded)) {
3676 llvm::transform(argTypes, std::back_inserter(gangTypes),
3677 [=](mlir::acc::GangArgType gangTy) {
3678 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3683 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
/// Attaches a private clause to this loop.
/// \param context MLIR context used to build the symbol reference.
/// \param op      PrivateOp whose recipe attribute is set and whose result is
///                appended to this loop's private operands.
/// \param recipe  Recipe op whose symbol name `op` will reference.
3687void acc::LoopOp::addPrivatization(
MLIRContext *context,
3688 mlir::acc::PrivateOp op,
3689 mlir::acc::PrivateRecipeOp recipe) {
// Point `op` at the recipe by symbol name, then register its result here.
3690 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3691 getPrivateOperandsMutable().append(op.getResult());
/// Attaches a firstprivate clause to this loop.
/// \param context MLIR context used to build the symbol reference.
/// \param op      FirstprivateOp whose recipe attribute is set and whose
///                result is appended to this loop's firstprivate operands.
/// \param recipe  Recipe op whose symbol name `op` will reference.
3694void acc::LoopOp::addFirstPrivatization(
3695 MLIRContext *context, mlir::acc::FirstprivateOp op,
3696 mlir::acc::FirstprivateRecipeOp recipe) {
// Point `op` at the recipe by symbol name, then register its result here.
3697 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3698 getFirstprivateOperandsMutable().append(op.getResult());
/// Attaches a reduction clause to this loop.
/// \param context MLIR context used to build the symbol reference.
/// \param op      ReductionOp whose recipe attribute is set and whose result
///                is appended to this loop's reduction operands.
/// \param recipe  Recipe op whose symbol name `op` will reference.
3701void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3702 mlir::acc::ReductionRecipeOp recipe) {
// Point `op` at the recipe by symbol name, then register its result here.
3703 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3704 getReductionOperandsMutable().append(op.getResult());
3711LogicalResult acc::DataOp::verify() {
3716 return emitError(
"at least one operand or the default attribute "
3717 "must appear on the data operation");
3719 for (
mlir::Value operand : getDataClauseOperands())
3720 if (isa<BlockArgument>(operand) ||
3721 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3722 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3723 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3724 operand.getDefiningOp()))
3725 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3734unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3736Value DataOp::getDataOperand(
unsigned i) {
3737 unsigned numOptional = getIfCond() ? 1 : 0;
3739 numOptional += getWaitOperands().size();
3740 return getOperand(numOptional + i);
/// Convenience overload: forwards to hasAsyncOnly(DeviceType) with
/// DeviceType::None.
3743bool acc::DataOp::hasAsyncOnly() {
3744 return hasAsyncOnly(mlir::acc::DeviceType::None);
3747bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3752 return getAsyncValue(mlir::acc::DeviceType::None);
3755mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3760bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3762bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3767 return getWaitValues(mlir::acc::DeviceType::None);
3771DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3773 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3774 getHasWaitDevnum(), deviceType);
3778 return getWaitDevnum(mlir::acc::DeviceType::None);
3781mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3783 getWaitOperandsSegments(), getHasWaitDevnum(),
3787void acc::DataOp::addAsyncOnly(
3789 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3790 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3793void acc::DataOp::addAsyncOperand(
3796 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3797 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3798 getAsyncOperandsMutable()));
3801void acc::DataOp::addWaitOnly(
MLIRContext *context,
3803 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3804 effectiveDeviceTypes));
3807void acc::DataOp::addWaitOperands(
3812 if (getWaitOperandsSegments())
3813 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3815 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3816 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3817 getWaitOperandsMutable(), segments));
3818 setWaitOperandsSegments(segments);
3821 if (getHasWaitDevnumAttr())
3822 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3825 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3827 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
/// Verifier for ExitDataOp. The visible checks reject, in order: an empty
/// data-clause operand list, an async operand combined with the async
/// attribute, wait operands combined with the wait attribute, and a
/// wait_devnum given without any wait operands. Lines between the checks and
/// the final return are elided from this listing.
3834LogicalResult acc::ExitDataOp::verify() {
// At least one data-clause operand is required.
3838 if (getDataClauseOperands().empty())
3839 return emitError(
"at least one operand must be present in dataOperands on "
3840 "the exit data operation");
// The async operand and the async attribute are mutually exclusive.
3844 if (getAsyncOperand() && getAsync())
3845 return emitError(
"async attribute cannot appear with asyncOperand");
// Wait operands and the wait attribute are mutually exclusive.
3849 if (!getWaitOperands().empty() && getWait())
3850 return emitError(
"wait attribute cannot appear with waitOperands");
// A wait_devnum is only meaningful alongside wait operands.
3852 if (getWaitDevnum() && getWaitOperands().empty())
3853 return emitError(
"wait_devnum cannot appear without waitOperands");
/// Returns the number of data-clause operands held by this ExitDataOp.
3858unsigned ExitDataOp::getNumDataOperands() {
3859 return getDataClauseOperands().size();
/// Returns the `i`-th data operand by indexing into the op's flat operand
/// list past the leading non-data operands.
/// \param i Zero-based index among the data operands.
3862Value ExitDataOp::getDataOperand(
unsigned i) {
// Each optional single operand (if-condition, async operand, wait devnum)
// occupies one slot when present.
3863 unsigned numOptional = getIfCond() ? 1 : 0;
3864 numOptional += getAsyncOperand() ? 1 : 0;
3865 numOptional += getWaitDevnum() ? 1 : 0;
// The wait operands are skipped as well; this arithmetic assumes the data
// operands come after all of them in operand order.
3866 return getOperand(getWaitOperands().size() + numOptional + i);
3871 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3874void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3876 assert(effectiveDeviceTypes.empty());
3877 assert(!getAsyncAttr());
3878 assert(!getAsyncOperand());
3880 setAsyncAttr(mlir::UnitAttr::get(context));
3883void ExitDataOp::addAsyncOperand(
3886 assert(effectiveDeviceTypes.empty());
3887 assert(!getAsyncAttr());
3888 assert(!getAsyncOperand());
3890 getAsyncOperandMutable().append(newValue);
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getWaitAttr());
3897 assert(getWaitOperands().empty());
3898 assert(!getWaitDevnum());
3900 setWaitAttr(mlir::UnitAttr::get(context));
3903void ExitDataOp::addWaitOperands(
3906 assert(effectiveDeviceTypes.empty());
3907 assert(!getWaitAttr());
3908 assert(getWaitOperands().empty());
3909 assert(!getWaitDevnum());
3914 getWaitDevnumMutable().append(newValues.front());
3915 newValues = newValues.drop_front();
3918 getWaitOperandsMutable().append(newValues);
/// Verifier for EnterDataOp. The visible checks reject, in order: an empty
/// data-clause operand list, an async operand combined with the async
/// attribute, wait operands combined with the wait attribute, a wait_devnum
/// given without wait operands, and any data-clause operand whose defining op
/// is not an AttachOp, CreateOp, or CopyinOp. Lines between the checks and
/// the final return are elided from this listing.
3925LogicalResult acc::EnterDataOp::verify() {
// At least one data-clause operand is required.
3929 if (getDataClauseOperands().empty())
3930 return emitError(
"at least one operand must be present in dataOperands on "
3931 "the enter data operation");
// The async operand and the async attribute are mutually exclusive.
3935 if (getAsyncOperand() && getAsync())
3936 return emitError(
"async attribute cannot appear with asyncOperand");
// Wait operands and the wait attribute are mutually exclusive.
3940 if (!getWaitOperands().empty() && getWait())
3941 return emitError(
"wait attribute cannot appear with waitOperands");
// A wait_devnum is only meaningful alongside wait operands.
3943 if (getWaitDevnum() && getWaitOperands().empty())
3944 return emitError(
"wait_devnum cannot appear without waitOperands");
// Every data-clause operand must be produced by a data entry op.
// NOTE(review): unlike the acc::DataOp::verify loop earlier in this file,
// there is no isa<BlockArgument>(operand) guard here. For a block-argument
// operand getDefiningOp() would be null, and mlir::isa on a null pointer is
// expected to assert rather than produce this diagnostic — confirm, and
// consider a null-tolerant check.
3946 for (
mlir::Value operand : getDataClauseOperands())
3947 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3948 operand.getDefiningOp()))
3949 return emitError(
"expect data entry operation as defining op");
/// Returns the number of data-clause operands held by this EnterDataOp.
3954unsigned EnterDataOp::getNumDataOperands() {
3955 return getDataClauseOperands().size();
/// Returns the `i`-th data operand by indexing into the op's flat operand
/// list past the leading non-data operands.
/// \param i Zero-based index among the data operands.
3958Value EnterDataOp::getDataOperand(
unsigned i) {
// Each optional single operand (if-condition, async operand, wait devnum)
// occupies one slot when present.
3959 unsigned numOptional = getIfCond() ? 1 : 0;
3960 numOptional += getAsyncOperand() ? 1 : 0;
3961 numOptional += getWaitDevnum() ? 1 : 0;
// The wait operands are skipped as well; this arithmetic assumes the data
// operands come after all of them in operand order.
3962 return getOperand(getWaitOperands().size() + numOptional + i);
3967 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3970void EnterDataOp::addAsyncOnly(
3972 assert(effectiveDeviceTypes.empty());
3973 assert(!getAsyncAttr());
3974 assert(!getAsyncOperand());
3976 setAsyncAttr(mlir::UnitAttr::get(context));
3979void EnterDataOp::addAsyncOperand(
3982 assert(effectiveDeviceTypes.empty());
3983 assert(!getAsyncAttr());
3984 assert(!getAsyncOperand());
3986 getAsyncOperandMutable().append(newValue);
3989void EnterDataOp::addWaitOnly(
MLIRContext *context,
3991 assert(effectiveDeviceTypes.empty());
3992 assert(!getWaitAttr());
3993 assert(getWaitOperands().empty());
3994 assert(!getWaitDevnum());
3996 setWaitAttr(mlir::UnitAttr::get(context));
3999void EnterDataOp::addWaitOperands(
4002 assert(effectiveDeviceTypes.empty());
4003 assert(!getWaitAttr());
4004 assert(getWaitOperands().empty());
4005 assert(!getWaitDevnum());
4010 getWaitDevnumMutable().append(newValues.front());
4011 newValues = newValues.drop_front();
4014 getWaitOperandsMutable().append(newValues);
4021LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4027LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4033LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4040 if (
Value writeVal = op.getWriteOpVal()) {
4049LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4051LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
/// Looks for an AtomicReadOp among the two ops returned by getFirstOp() and
/// getSecondOp(). The first op is tried first; otherwise the second op is
/// dyn_cast to AtomicReadOp, which yields a null op handle if it is not one.
/// NOTE(review): the statement executed when the first op matches is elided
/// from this listing — presumably `return op;`; confirm.
4057AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4058 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4060 return dyn_cast<AtomicReadOp>(getSecondOp());
/// Looks for an AtomicWriteOp among the two ops returned by getFirstOp() and
/// getSecondOp(). The first op is tried first; otherwise the second op is
/// dyn_cast to AtomicWriteOp, which yields a null op handle if it is not one.
/// NOTE(review): the statement executed when the first op matches is elided
/// from this listing — presumably `return op;`; confirm.
4063AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4064 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4066 return dyn_cast<AtomicWriteOp>(getSecondOp());
/// Looks for an AtomicUpdateOp among the two ops returned by getFirstOp() and
/// getSecondOp(). The first op is tried first; otherwise the second op is
/// dyn_cast to AtomicUpdateOp, which yields a null op handle if it is not
/// one.
/// NOTE(review): the statement executed when the first op matches is elided
/// from this listing — presumably `return op;`; confirm.
4069AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4070 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4072 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4075LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4081template <
typename Op>
4084 bool requireAtLeastOneOperand =
true) {
4085 if (operands.empty() && requireAtLeastOneOperand)
4088 "at least one operand must appear on the declare operation");
4091 if (isa<BlockArgument>(operand) ||
4092 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4093 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4094 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4095 operand.getDefiningOp()))
4097 "expect valid declare data entry operation or acc.getdeviceptr "
4101 assert(var &&
"declare operands can only be data entry operations which "
4104 std::optional<mlir::acc::DataClause> dataClauseOptional{
4106 assert(dataClauseOptional.has_value() &&
4107 "declare operands can only be data entry operations which must have "
4109 (
void)dataClauseOptional;
4115LogicalResult acc::DeclareEnterOp::verify() {
4123LogicalResult acc::DeclareExitOp::verify() {
4134LogicalResult acc::DeclareOp::verify() {
4143 acc::DeviceType dtype) {
4144 unsigned parallelism = 0;
4145 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4146 parallelism += op.hasWorker(dtype) ? 1 : 0;
4147 parallelism += op.hasVector(dtype) ? 1 : 0;
4148 parallelism += op.hasSeq(dtype) ? 1 : 0;
/// Verifier for RoutineOp: enforces that at most one of `gang`, `worker`,
/// `vector`, `seq` is present, both for the baseline count and per device
/// type. Several lines (the baseline initializer, the loop increment, the
/// per-device-type count, and the final return) are elided from this listing.
4152LogicalResult acc::RoutineOp::verify() {
// Baseline count compared against each device type below; its initializer
// is elided from this listing.
4153 unsigned baseParallelism =
4156 if (baseParallelism > 1)
4157 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4158 "be present at the same time";
// Walk the DeviceType enumerators by integer value.
// NOTE(review): if getMaxEnumValForDeviceType() returns the largest
// enumerator value (not a count), the `!=` bound never visits that last
// device type — confirm whether `<=` was intended.
4160 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4162 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
// DeviceType::None gets special handling; the statement is elided here —
// presumably `continue`, since the baseline was checked above; confirm.
4163 if (dtype == acc::DeviceType::None)
// Reject a device type that has more than one clause by itself, or exactly
// one clause on top of exactly one baseline clause.
4167 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4168 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4169 "be present at the same time for device_type `"
4170 << acc::stringifyDeviceType(dtype) <<
"`";
4177 mlir::ArrayAttr &bindIdName,
4178 mlir::ArrayAttr &bindStrName,
4179 mlir::ArrayAttr &deviceIdTypes,
4180 mlir::ArrayAttr &deviceStrTypes) {
4187 mlir::Attribute newAttr;
4188 bool isSymbolRefAttr;
4189 auto parseResult = parser.parseAttribute(newAttr);
4190 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4191 bindIdNameAttrs.push_back(symbolRefAttr);
4192 isSymbolRefAttr = true;
4193 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4194 bindStrNameAttrs.push_back(stringAttr);
4195 isSymbolRefAttr =
false;
4200 if (isSymbolRefAttr) {
4201 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4202 parser.getContext(), mlir::acc::DeviceType::None));
4204 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4205 parser.getContext(), mlir::acc::DeviceType::None));
4208 if (isSymbolRefAttr) {
4209 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4210 parser.parseRSquare())
4213 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4214 parser.parseRSquare())
4222 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4223 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4224 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4225 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4231 std::optional<mlir::ArrayAttr> bindIdName,
4232 std::optional<mlir::ArrayAttr> bindStrName,
4233 std::optional<mlir::ArrayAttr> deviceIdTypes,
4234 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4241 allBindNames.append(bindIdName->begin(), bindIdName->end());
4242 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4247 allBindNames.append(bindStrName->begin(), bindStrName->end());
4248 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4252 if (!allBindNames.empty())
4253 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4254 [&](
const auto &pair) {
4255 p << std::get<0>(pair);
4261 mlir::ArrayAttr &gang,
4262 mlir::ArrayAttr &gangDim,
4263 mlir::ArrayAttr &gangDimDeviceTypes) {
4266 gangDimDeviceTypeAttrs;
4267 bool needCommaBeforeOperands =
false;
4271 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4272 parser.
getContext(), mlir::acc::DeviceType::None));
4273 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4280 if (parser.parseAttribute(gangAttrs.emplace_back()))
4287 needCommaBeforeOperands =
true;
4290 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4294 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4295 parser.parseColon() ||
4296 parser.parseAttribute(gangDimAttrs.emplace_back()))
4298 if (succeeded(parser.parseOptionalLSquare())) {
4299 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4300 parser.parseRSquare())
4303 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4304 parser.getContext(), mlir::acc::DeviceType::None));
4310 if (
failed(parser.parseRParen()))
4313 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4314 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4315 gangDimDeviceTypes =
4316 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4322 std::optional<mlir::ArrayAttr> gang,
4323 std::optional<mlir::ArrayAttr> gangDim,
4324 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4327 gang->size() == 1) {
4328 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4329 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4341 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4342 [&](
const auto &pair) {
4343 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4344 p << std::get<0>(pair);
4352 mlir::ArrayAttr &deviceTypes) {
4356 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4357 parser.
getContext(), mlir::acc::DeviceType::None));
4358 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4365 if (parser.parseAttribute(attributes.emplace_back()))
4373 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4379 std::optional<mlir::ArrayAttr> deviceTypes) {
4382 auto deviceTypeAttr =
4383 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4384 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4393 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4399bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4401bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4405bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4407bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4411bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4413bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to getBindNameValue(DeviceType) with
/// DeviceType::None.
4417std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4418RoutineOp::getBindNameValue() {
4419 return getBindNameValue(mlir::acc::DeviceType::None);
/// Looks up the bind name registered for `deviceType`. The symbol-reference
/// bind names are searched first, then the string bind names; the result is
/// std::nullopt when neither lookup finds a segment for `deviceType`.
/// Guard conditions and the string-branch return are elided from this
/// listing.
4422std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4423RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
// Early exit; the guarding condition is elided from this listing.
4426 return std::nullopt;
// Symbol-reference bind names: `pos` indexes the name array in parallel with
// its device-type array.
4429 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4430 auto attr = (*getBindIdName())[*pos];
4431 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4432 assert(symbolRefAttr &&
"expected SymbolRef");
4433 return symbolRefAttr;
// String bind names, looked up the same way.
// NOTE(review): the return inside this branch is elided here — presumably
// `return stringAttr;`; confirm.
4436 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4437 auto attr = (*getBindStrName())[*pos];
4438 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4439 assert(stringAttr &&
"expected String");
// No bind name recorded for this device type.
4443 return std::nullopt;
4446bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4448bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to getGangDimValue(DeviceType) with
/// DeviceType::None.
4452std::optional<int64_t> RoutineOp::getGangDimValue() {
4453 return getGangDimValue(mlir::acc::DeviceType::None);
/// Returns the gang dim value recorded for `deviceType`, or std::nullopt when
/// no segment for that device type is found.
4456std::optional<int64_t>
4457RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
// Early exit; the guarding condition is elided from this listing.
4459 return std::nullopt;
// `pos` indexes the gang dim array in parallel with its device-type array.
4460 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
// NOTE(review): the dyn_cast result is used without a null check; if the
// element is always an IntegerAttr, `cast` would state that invariant
// explicitly — confirm.
4461 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4462 return intAttr.getInt();
4464 return std::nullopt;
4469 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4470 effectiveDeviceTypes));
4475 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4476 effectiveDeviceTypes));
4481 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4482 effectiveDeviceTypes));
4487 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4488 effectiveDeviceTypes));
4497 if (getGangDimAttr())
4498 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4499 if (getGangDimDeviceTypeAttr())
4500 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4502 assert(dimValues.size() == deviceTypes.size());
4504 if (effectiveDeviceTypes.empty()) {
4505 dimValues.push_back(
4506 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4507 deviceTypes.push_back(
4508 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4510 for (DeviceType dt : effectiveDeviceTypes) {
4511 dimValues.push_back(
4512 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4513 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4516 assert(dimValues.size() == deviceTypes.size());
4518 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4519 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4522void RoutineOp::addBindStrName(
MLIRContext *context,
4524 mlir::StringAttr val) {
4525 unsigned before = getBindStrNameDeviceTypeAttr()
4526 ? getBindStrNameDeviceTypeAttr().size()
4529 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4530 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4531 unsigned after = getBindStrNameDeviceTypeAttr().size();
4534 if (getBindStrNameAttr())
4535 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4536 for (
unsigned i = 0; i < after - before; ++i)
4537 vals.push_back(val);
4539 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4542void RoutineOp::addBindIDName(
MLIRContext *context,
4544 mlir::SymbolRefAttr val) {
4546 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4548 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4549 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4550 unsigned after = getBindIdNameDeviceTypeAttr().size();
4553 if (getBindIdNameAttr())
4554 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4555 for (
unsigned i = 0; i < after - before; ++i)
4556 vals.push_back(val);
4558 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4565LogicalResult acc::InitOp::verify() {
4569 return emitOpError(
"cannot be nested in a compute operation");
4573void acc::InitOp::addDeviceType(
MLIRContext *context,
4574 mlir::acc::DeviceType deviceType) {
4576 if (getDeviceTypesAttr())
4577 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4579 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4580 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4587LogicalResult acc::ShutdownOp::verify() {
4591 return emitOpError(
"cannot be nested in a compute operation");
4595void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4596 mlir::acc::DeviceType deviceType) {
4598 if (getDeviceTypesAttr())
4599 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4601 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4602 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
/// Verifier for SetOp. Visible checks: a nesting restriction (its guarding
/// condition is elided from this listing) and the requirement that at least
/// one of the device-type attribute, default_async, or device_num is present.
4609LogicalResult acc::SetOp::verify() {
// Nesting restriction; the condition that triggers it is elided here.
4613 return emitOpError(
"cannot be nested in a compute operation");
// The op must carry at least one of its three settings.
4614 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4615 return emitOpError(
"at least one default_async, device_num, or device_type "
4616 "operand must appear");
4624LogicalResult acc::UpdateOp::verify() {
4626 if (getDataClauseOperands().empty())
4627 return emitError(
"at least one value must be present in dataOperands");
4630 getAsyncOperandsDeviceTypeAttr(),
4635 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4636 getWaitOperandsDeviceTypeAttr(),
"wait")))
4642 for (
mlir::Value operand : getDataClauseOperands())
4643 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4644 operand.getDefiningOp()))
4645 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
/// Returns the number of data-clause operands held by this UpdateOp.
4651unsigned UpdateOp::getNumDataOperands() {
4652 return getDataClauseOperands().size();
4655Value UpdateOp::getDataOperand(
unsigned i) {
4657 numOptional += getIfCond() ? 1 : 0;
4658 return getOperand(getWaitOperands().size() + numOptional + i);
4663 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
/// Convenience overload: forwards to hasAsyncOnly(DeviceType) with
/// DeviceType::None.
4666bool UpdateOp::hasAsyncOnly() {
4667 return hasAsyncOnly(mlir::acc::DeviceType::None);
4670bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4675 return getAsyncValue(mlir::acc::DeviceType::None);
4678mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
/// Convenience overload: forwards to hasWaitOnly(DeviceType) with
/// DeviceType::None.
4688bool UpdateOp::hasWaitOnly() {
4689 return hasWaitOnly(mlir::acc::DeviceType::None);
4692bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4697 return getWaitValues(mlir::acc::DeviceType::None);
4701UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4703 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4704 getHasWaitDevnum(), deviceType);
4708 return getWaitDevnum(mlir::acc::DeviceType::None);
4711mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4713 getWaitOperandsSegments(), getHasWaitDevnum(),
4719 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4720 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4723void UpdateOp::addAsyncOperand(
4726 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4727 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4728 getAsyncOperandsMutable()));
4733 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4734 effectiveDeviceTypes));
4737void UpdateOp::addWaitOperands(
4742 if (getWaitOperandsSegments())
4743 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4745 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4746 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4747 getWaitOperandsMutable(), segments));
4748 setWaitOperandsSegments(segments);
4751 if (getHasWaitDevnumAttr())
4752 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4755 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4757 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
/// Verifier for WaitOp. The visible checks reject an async operand combined
/// with the async attribute, and a wait_devnum given without any wait
/// operands. Lines around the checks, including the final return, are elided
/// from this listing.
4764LogicalResult acc::WaitOp::verify() {
// The async operand and the async attribute are mutually exclusive.
4767 if (getAsyncOperand() && getAsync())
4768 return emitError(
"async attribute cannot appear with asyncOperand");
// A wait_devnum is only meaningful alongside wait operands.
4770 if (getWaitDevnum() && getWaitOperands().empty())
4771 return emitError(
"wait_devnum cannot appear without waitOperands");
4776#define GET_OP_CLASSES
4777#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4779#define GET_ATTRDEF_CLASSES
4780#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4782#define GET_TYPEDEF_CLASSES
4783#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4794 .Case<ACC_DATA_ENTRY_OPS>(
4795 [&](
auto entry) {
return entry.getVarPtr(); })
4796 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4797 [&](
auto exit) {
return exit.getVarPtr(); })
4815 [&](
auto entry) {
return entry.getVarType(); })
4816 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4817 [&](
auto exit) {
return exit.getVarType(); })
4827 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4828 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4838 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4847 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4857 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4859 dataClause.getBounds().begin(), dataClause.getBounds().end());
4871 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4873 dataClause.getAsyncOperands().begin(),
4874 dataClause.getAsyncOperands().end());
4885 return dataClause.getAsyncOperandsDeviceTypeAttr();
4893 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4900 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4907std::optional<mlir::acc::DataClause>
4912 .Case<ACC_DATA_ENTRY_OPS>(
4913 [&](
auto entry) {
return entry.getDataClause(); })
4921 [&](
auto entry) {
return entry.getImplicit(); })
4930 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4932 return dataOperands;
4940 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4942 return dataOperands;
4949 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.