25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
250struct LLVMPointerPointerLikeModel
251 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
252 LLVM::LLVMPointerType> {
255 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
257 Type valueType)
const {
262 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
265 bool genStore(Type pointer, OpBuilder &builder, Location loc,
267 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
272struct MemrefAddressOfGlobalModel
273 :
public AddressOfGlobalOpInterface::ExternalModel<
274 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
275 SymbolRefAttr getSymbol(Operation *op)
const {
276 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
277 return getGlobalOp.getNameAttr();
281struct MemrefGlobalVariableModel
282 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
284 bool isConstant(Operation *op)
const {
285 auto globalOp = cast<memref::GlobalOp>(op);
286 return globalOp.getConstant();
289 Region *getInitRegion(Operation *op)
const {
295struct GPULaunchOffloadRegionModel
296 :
public acc::OffloadRegionOpInterface::ExternalModel<
297 GPULaunchOffloadRegionModel, gpu::LaunchOp> {};
303mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
304 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
307 if (existingDeviceTypes)
308 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
310 if (newDeviceTypes.empty())
311 deviceTypes.push_back(
312 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
314 for (DeviceType dt : newDeviceTypes)
315 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
317 return mlir::ArrayAttr::get(context, deviceTypes);
326mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
327 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
332 if (existingDeviceTypes)
333 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
335 if (newDeviceTypes.empty()) {
336 argCollection.
append(arguments);
337 segments.push_back(arguments.size());
338 deviceTypes.push_back(
339 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
342 for (DeviceType dt : newDeviceTypes) {
343 argCollection.
append(arguments);
344 segments.push_back(arguments.size());
345 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
348 return mlir::ArrayAttr::get(context, deviceTypes);
352mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
353 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
357 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
358 newDeviceTypes, arguments,
359 argCollection, segments);
367void OpenACCDialect::initialize() {
370#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
373#define GET_ATTRDEF_LIST
374#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
377#define GET_TYPEDEF_LIST
378#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
384 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
386 UnrankedMemRefType::attachInterface<
387 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
388 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
392 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
394 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
395 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
423void ParallelOp::getSuccessorRegions(
435void KernelEnvironmentOp::getSuccessorRegions(
447void HostDataOp::getSuccessorRegions(
458 if (getUnstructured()) {
487 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
491 mlir::acc::DeviceType deviceType) {
495 for (
auto attr : *arrayAttr) {
496 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
497 if (deviceTypeAttr.getValue() == deviceType)
505 std::optional<mlir::ArrayAttr> deviceTypes) {
510 llvm::interleaveComma(*deviceTypes, p,
516 mlir::acc::DeviceType deviceType) {
517 unsigned segmentIdx = 0;
518 for (
auto attr : segments) {
519 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
520 if (deviceTypeAttr.getValue() == deviceType)
521 return std::make_optional(segmentIdx);
531 mlir::acc::DeviceType deviceType) {
533 return range.take_front(0);
534 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
535 int32_t nbOperandsBefore = 0;
536 for (
unsigned i = 0; i < *pos; ++i)
537 nbOperandsBefore += (*segments)[i];
538 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
540 return range.take_front(0);
547 std::optional<mlir::ArrayAttr> hasWaitDevnum,
548 mlir::acc::DeviceType deviceType) {
551 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
552 if (hasWaitDevnum->getValue()[*pos])
563 std::optional<mlir::ArrayAttr> hasWaitDevnum,
564 mlir::acc::DeviceType deviceType) {
569 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
570 if (hasWaitDevnum && *hasWaitDevnum) {
571 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
572 if (boolAttr.getValue())
573 return range.drop_front(1);
579template <
typename Op>
581 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
583 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
588 op.hasAsyncOnly(dtype))
590 "asyncOnly attribute cannot appear with asyncOperand");
595 op.hasWaitOnly(dtype))
596 return op.
emitError(
"wait attribute cannot appear with waitOperands");
601template <
typename Op>
604 return op.
emitError(
"must have var operand");
607 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
608 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
609 return op.
emitError(
"var must be mappable or pointer-like");
612 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
613 op.getVarType() == op.getVar().getType())
614 return op.
emitError(
"varType must capture the element type of var");
619template <
typename Op>
621 if (op.getVar().getType() != op.getAccVar().getType())
622 return op.
emitError(
"input and output types must match");
627template <
typename Op>
629 if (op.getModifiers() != acc::DataClauseModifier::none)
630 return op.
emitError(
"no data clause modifiers are allowed");
634template <
typename Op>
637 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
639 "invalid data clause modifiers: " +
640 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
645template <
typename OpT,
typename RecipeOpT>
646static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
651 !std::is_same_v<OpT, acc::ReductionOp>)
654 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
656 return op->emitOpError() <<
"recipe expected for " << operandName;
661 return op->emitOpError()
662 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
663 << operandName <<
" declaration";
684 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
705 if (failed(parser.
parseType(accVarType)))
715 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
727 mlir::TypeAttr &varTypeAttr) {
728 if (failed(parser.
parseType(varPtrType)))
739 varTypeAttr = mlir::TypeAttr::get(varType);
744 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
745 varTypeAttr = mlir::TypeAttr::get(
746 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
748 varTypeAttr = mlir::TypeAttr::get(varPtrType);
755 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
763 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
764 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
766 if (typeToCheckAgainst != varType) {
774 mlir::SymbolRefAttr &recipeAttr) {
781 mlir::SymbolRefAttr recipeAttr) {
788LogicalResult acc::DataBoundsOp::verify() {
789 auto extent = getExtent();
790 auto upperbound = getUpperbound();
791 if (!extent && !upperbound)
792 return emitError(
"expected extent or upperbound.");
799LogicalResult acc::PrivateOp::verify() {
802 "data clause associated with private operation must match its intent");
816LogicalResult acc::FirstprivateOp::verify() {
818 return emitError(
"data clause associated with firstprivate operation must "
825 *
this,
"firstprivate")))
833LogicalResult acc::FirstprivateMapInitialOp::verify() {
835 return emitError(
"data clause associated with firstprivate operation must "
847LogicalResult acc::ReductionOp::verify() {
849 return emitError(
"data clause associated with reduction operation must "
856 *
this,
"reduction")))
864LogicalResult acc::DevicePtrOp::verify() {
866 return emitError(
"data clause associated with deviceptr operation must "
880LogicalResult acc::PresentOp::verify() {
883 "data clause associated with present operation must match its intent");
896LogicalResult acc::CopyinOp::verify() {
898 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
903 "data clause associated with copyin operation must match its intent"
904 " or specify original clause this operation was decomposed from");
910 acc::DataClauseModifier::always |
911 acc::DataClauseModifier::capture)))
916bool acc::CopyinOp::isCopyinReadonly() {
917 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
918 acc::bitEnumContainsAny(getModifiers(),
919 acc::DataClauseModifier::readonly);
925LogicalResult acc::CreateOp::verify() {
932 "data clause associated with create operation must match its intent"
933 " or specify original clause this operation was decomposed from");
941 acc::DataClauseModifier::always |
942 acc::DataClauseModifier::capture)))
947bool acc::CreateOp::isCreateZero() {
949 return getDataClause() == acc::DataClause::acc_create_zero ||
951 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
957LogicalResult acc::NoCreateOp::verify() {
959 return emitError(
"data clause associated with no_create operation must "
973LogicalResult acc::AttachOp::verify() {
976 "data clause associated with attach operation must match its intent");
990LogicalResult acc::DeclareDeviceResidentOp::verify() {
991 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
992 return emitError(
"data clause associated with device_resident operation "
993 "must match its intent");
1007LogicalResult acc::DeclareLinkOp::verify() {
1010 "data clause associated with link operation must match its intent");
1023LogicalResult acc::CopyoutOp::verify() {
1030 "data clause associated with copyout operation must match its intent"
1031 " or specify original clause this operation was decomposed from");
1033 return emitError(
"must have both host and device pointers");
1039 acc::DataClauseModifier::always |
1040 acc::DataClauseModifier::capture)))
1045bool acc::CopyoutOp::isCopyoutZero() {
1046 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1047 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1053LogicalResult acc::DeleteOp::verify() {
1062 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1065 "data clause associated with delete operation must match its intent"
1066 " or specify original clause this operation was decomposed from");
1068 return emitError(
"must have device pointer");
1072 acc::DataClauseModifier::readonly |
1073 acc::DataClauseModifier::always |
1074 acc::DataClauseModifier::capture)))
1082LogicalResult acc::DetachOp::verify() {
1087 "data clause associated with detach operation must match its intent"
1088 " or specify original clause this operation was decomposed from");
1090 return emitError(
"must have device pointer");
1099LogicalResult acc::UpdateHostOp::verify() {
1104 "data clause associated with host operation must match its intent"
1105 " or specify original clause this operation was decomposed from");
1107 return emitError(
"must have both host and device pointers");
1120LogicalResult acc::UpdateDeviceOp::verify() {
1124 "data clause associated with device operation must match its intent"
1125 " or specify original clause this operation was decomposed from");
1138LogicalResult acc::UseDeviceOp::verify() {
1142 "data clause associated with use_device operation must match its intent"
1143 " or specify original clause this operation was decomposed from");
1156LogicalResult acc::CacheOp::verify() {
1161 "data clause associated with cache operation must match its intent"
1162 " or specify original clause this operation was decomposed from");
1172bool acc::CacheOp::isCacheReadonly() {
1173 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1174 acc::bitEnumContainsAny(getModifiers(),
1175 acc::DataClauseModifier::readonly);
1178template <
typename StructureOp>
1180 unsigned nRegions = 1) {
1183 for (
unsigned i = 0; i < nRegions; ++i)
1186 for (
Region *region : regions)
1194 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1201template <
typename OpTy>
1203 using OpRewritePattern<OpTy>::OpRewritePattern;
1205 LogicalResult matchAndRewrite(OpTy op,
1206 PatternRewriter &rewriter)
const override {
1208 Value ifCond = op.getIfCond();
1212 IntegerAttr constAttr;
1215 if (constAttr.getInt())
1216 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1228 assert(region.
hasOneBlock() &&
"expected single-block region");
1240template <
typename OpTy>
1241struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1242 using OpRewritePattern<OpTy>::OpRewritePattern;
1244 LogicalResult matchAndRewrite(OpTy op,
1245 PatternRewriter &rewriter)
const override {
1247 Value ifCond = op.getIfCond();
1251 IntegerAttr constAttr;
1254 if (constAttr.getInt())
1255 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1265struct RemoveEmptyKernelEnvironment
1267 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1269 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1270 PatternRewriter &rewriter)
const override {
1271 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1282 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1283 for (
auto attr : deviceTypeAttr) {
1284 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1285 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1292 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1293 for (
auto attr : hasDevnumAttr) {
1294 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1295 if (boolAttr.getValue())
1302 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1303 if (segmentsAttr.size() > 1)
1309 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1336 for (
Value bound : bounds) {
1337 argTypes.push_back(bound.getType());
1338 argLocs.push_back(loc);
1345 Value privatizedValue;
1351 if (isa<MappableType>(varType)) {
1352 auto mappableTy = cast<MappableType>(varType);
1353 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1354 privatizedValue = mappableTy.generatePrivateInit(
1355 builder, loc, typedVar, varName, bounds, {}, needsFree);
1356 if (!privatizedValue)
1359 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1360 auto pointerLikeTy = cast<PointerLikeType>(varType);
1362 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1363 blockArgVar, needsFree);
1364 if (!privatizedValue)
1369 acc::YieldOp::create(builder, loc, privatizedValue);
1384 for (
Value bound : bounds) {
1385 copyArgTypes.push_back(bound.getType());
1386 copyArgLocs.push_back(loc);
1393 bool isMappable = isa<MappableType>(varType);
1394 bool isPointerLike = isa<PointerLikeType>(varType);
1397 if (isMappable && !isPointerLike)
1401 if (isPointerLike) {
1402 auto pointerLikeTy = cast<PointerLikeType>(varType);
1407 if (!pointerLikeTy.genCopy(
1414 acc::TerminatorOp::create(builder, loc);
1428 for (
Value bound : bounds) {
1429 destroyArgTypes.push_back(bound.getType());
1430 destroyArgLocs.push_back(loc);
1434 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1438 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1439 if (isa<MappableType>(varType)) {
1440 auto mappableTy = cast<MappableType>(varType);
1441 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1444 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1445 auto pointerLikeTy = cast<PointerLikeType>(varType);
1446 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1450 acc::TerminatorOp::create(builder, loc);
1461 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1463 if (optional && region.
empty())
1467 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1471 return op->
emitOpError() <<
"expects " << regionName
1474 << regionType <<
" type";
1477 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1478 if (yieldOp.getOperands().size() != 1 ||
1479 yieldOp.getOperands().getTypes()[0] != type)
1480 return op->
emitOpError() <<
"expects " << regionName
1482 "yield a value of the "
1483 << regionType <<
" type";
1489LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1491 "privatization",
"init",
getType(),
1495 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1501std::optional<PrivateRecipeOp>
1503 StringRef recipeName,
Type varType,
1506 bool isMappable = isa<MappableType>(varType);
1507 bool isPointerLike = isa<PointerLikeType>(varType);
1510 if (!isMappable && !isPointerLike)
1511 return std::nullopt;
1516 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1519 bool needsFree =
false;
1520 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1521 varName, bounds, needsFree))) {
1523 return std::nullopt;
1530 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1531 Value allocRes = yieldOp.getOperand(0);
1533 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1534 varType, allocRes, bounds))) {
1536 return std::nullopt;
1543std::optional<PrivateRecipeOp>
1545 StringRef recipeName,
1546 FirstprivateRecipeOp firstprivRecipe) {
1549 auto varType = firstprivRecipe.getType();
1550 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1554 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1557 if (!firstprivRecipe.getDestroyRegion().empty()) {
1559 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1569LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1571 "privatization",
"init",
getType(),
1575 if (getCopyRegion().empty())
1576 return emitOpError() <<
"expects non-empty copy region";
1581 return emitOpError() <<
"expects copy region with two arguments of the "
1582 "privatization type";
1584 if (getDestroyRegion().empty())
1588 "privatization",
"destroy",
1595std::optional<FirstprivateRecipeOp>
1597 StringRef recipeName,
Type varType,
1600 bool isMappable = isa<MappableType>(varType);
1601 bool isPointerLike = isa<PointerLikeType>(varType);
1604 if (!isMappable && !isPointerLike)
1605 return std::nullopt;
1610 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1613 bool needsFree =
false;
1614 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1615 varName, bounds, needsFree))) {
1617 return std::nullopt;
1621 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1624 return std::nullopt;
1631 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1632 Value allocRes = yieldOp.getOperand(0);
1634 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1635 varType, allocRes, bounds))) {
1637 return std::nullopt;
1648LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1654 if (getCombinerRegion().empty())
1655 return emitOpError() <<
"expects non-empty combiner region";
1657 Block &reductionBlock = getCombinerRegion().
front();
1661 return emitOpError() <<
"expects combiner region with the first two "
1662 <<
"arguments of the reduction type";
1664 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1665 if (yieldOp.getOperands().size() != 1 ||
1666 yieldOp.getOperands().getTypes()[0] !=
getType())
1667 return emitOpError() <<
"expects combiner region to yield a value "
1668 "of the reduction type";
1679template <
typename Op>
1683 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1684 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1685 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1686 operand.getDefiningOp()))
1688 "expect data entry/exit operation or acc.getdeviceptr "
1693template <
typename OpT,
typename RecipeOpT>
1696 llvm::StringRef operandName) {
1699 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1701 <<
"expected " << operandName <<
" as defining op";
1702 if (!set.insert(operand).second)
1704 << operandName <<
" operand appears more than once";
1709unsigned ParallelOp::getNumDataOperands() {
1710 return getReductionOperands().size() + getPrivateOperands().size() +
1711 getFirstprivateOperands().size() + getDataClauseOperands().size();
1714Value ParallelOp::getDataOperand(
unsigned i) {
1716 numOptional += getNumGangs().size();
1717 numOptional += getNumWorkers().size();
1718 numOptional += getVectorLength().size();
1719 numOptional += getIfCond() ? 1 : 0;
1720 numOptional += getSelfCond() ? 1 : 0;
1721 return getOperand(getWaitOperands().size() + numOptional + i);
1724template <
typename Op>
1727 llvm::StringRef keyword) {
1728 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1729 return op.
emitOpError() << keyword <<
" operands count must match "
1730 << keyword <<
" device_type count";
1734template <
typename Op>
1737 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1738 std::size_t numOperandsInSegments = 0;
1739 std::size_t nbOfSegments = 0;
1742 for (
auto segCount : segments.
asArrayRef()) {
1743 if (maxInSegment != 0 && segCount > maxInSegment)
1744 return op.
emitOpError() << keyword <<
" expects a maximum of "
1745 << maxInSegment <<
" values per segment";
1746 numOperandsInSegments += segCount;
1751 if ((numOperandsInSegments != operands.size()) ||
1752 (!deviceTypes && !operands.empty()))
1754 << keyword <<
" operand count does not match count in segments";
1755 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1757 << keyword <<
" segment count does not match device_type count";
1761LogicalResult acc::ParallelOp::verify() {
1763 mlir::acc::PrivateRecipeOp>(
1764 *
this, getPrivateOperands(),
"private")))
1767 mlir::acc::FirstprivateRecipeOp>(
1768 *
this, getFirstprivateOperands(),
"firstprivate")))
1771 mlir::acc::ReductionRecipeOp>(
1772 *
this, getReductionOperands(),
"reduction")))
1776 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1777 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1781 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1782 getWaitOperandsDeviceTypeAttr(),
"wait")))
1786 getNumWorkersDeviceTypeAttr(),
1791 getVectorLengthDeviceTypeAttr(),
1796 getAsyncOperandsDeviceTypeAttr(),
1809 mlir::acc::DeviceType deviceType) {
1812 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1817bool acc::ParallelOp::hasAsyncOnly() {
1818 return hasAsyncOnly(mlir::acc::DeviceType::None);
1821bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1826 return getAsyncValue(mlir::acc::DeviceType::None);
1829mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1834mlir::Value acc::ParallelOp::getNumWorkersValue() {
1835 return getNumWorkersValue(mlir::acc::DeviceType::None);
1839acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1844mlir::Value acc::ParallelOp::getVectorLengthValue() {
1845 return getVectorLengthValue(mlir::acc::DeviceType::None);
1849acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1851 getVectorLength(), deviceType);
1855 return getNumGangsValues(mlir::acc::DeviceType::None);
1859ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1861 getNumGangsSegments(), deviceType);
1864bool acc::ParallelOp::hasWaitOnly() {
1865 return hasWaitOnly(mlir::acc::DeviceType::None);
1868bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1873 return getWaitValues(mlir::acc::DeviceType::None);
1877ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1879 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1880 getHasWaitDevnum(), deviceType);
1884 return getWaitDevnum(mlir::acc::DeviceType::None);
1887mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1889 getWaitOperandsSegments(), getHasWaitDevnum(),
1904 odsBuilder, odsState, asyncOperands,
nullptr,
1905 nullptr, waitOperands,
nullptr,
1907 nullptr, numGangs,
nullptr,
1908 nullptr, numWorkers,
1909 nullptr, vectorLength,
1910 nullptr, ifCond, selfCond,
1911 nullptr, reductionOperands, gangPrivateOperands,
1912 gangFirstPrivateOperands, dataClauseOperands,
1916void acc::ParallelOp::addNumWorkersOperand(
1919 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1920 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1921 getNumWorkersMutable()));
1923void acc::ParallelOp::addVectorLengthOperand(
1926 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1927 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1928 getVectorLengthMutable()));
1931void acc::ParallelOp::addAsyncOnly(
1933 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1934 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1937void acc::ParallelOp::addAsyncOperand(
1940 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1941 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1942 getAsyncOperandsMutable()));
1945void acc::ParallelOp::addNumGangsOperands(
1949 if (getNumGangsSegments())
1950 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1952 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1953 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1954 getNumGangsMutable(), segments));
1956 setNumGangsSegments(segments);
1958void acc::ParallelOp::addWaitOnly(
1960 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1961 effectiveDeviceTypes));
1963void acc::ParallelOp::addWaitOperands(
1968 if (getWaitOperandsSegments())
1969 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1971 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1972 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1973 getWaitOperandsMutable(), segments));
1974 setWaitOperandsSegments(segments);
1977 if (getHasWaitDevnumAttr())
1978 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1981 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1983 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1986void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1987 mlir::acc::PrivateOp op,
1988 mlir::acc::PrivateRecipeOp recipe) {
1989 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1990 getPrivateOperandsMutable().append(op.getResult());
1993void acc::ParallelOp::addFirstPrivatization(
1994 MLIRContext *context, mlir::acc::FirstprivateOp op,
1995 mlir::acc::FirstprivateRecipeOp recipe) {
1996 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1997 getFirstprivateOperandsMutable().append(op.getResult());
2000void acc::ParallelOp::addReduction(
MLIRContext *context,
2001 mlir::acc::ReductionOp op,
2002 mlir::acc::ReductionRecipeOp recipe) {
2003 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2004 getReductionOperandsMutable().append(op.getResult());
2019 int32_t crtOperandsSize = operands.size();
2022 if (parser.parseOperand(operands.emplace_back()) ||
2023 parser.parseColonType(types.emplace_back()))
2028 seg.push_back(operands.size() - crtOperandsSize);
2038 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2039 parser.
getContext(), mlir::acc::DeviceType::None));
2045 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2052 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2053 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2054 p <<
" [" << attr <<
"]";
2059 std::optional<mlir::ArrayAttr> deviceTypes,
2060 std::optional<mlir::DenseI32ArrayAttr> segments) {
2062 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2064 llvm::interleaveComma(
2065 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2066 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2086 int32_t crtOperandsSize = operands.size();
2090 if (parser.parseOperand(operands.emplace_back()) ||
2091 parser.parseColonType(types.emplace_back()))
2097 seg.push_back(operands.size() - crtOperandsSize);
2107 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2108 parser.
getContext(), mlir::acc::DeviceType::None));
2114 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2123 std::optional<mlir::DenseI32ArrayAttr> segments) {
2125 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2127 llvm::interleaveComma(
2128 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2129 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2142 mlir::ArrayAttr &keywordOnly) {
2146 bool needCommaBeforeOperands =
false;
2150 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2151 parser.
getContext(), mlir::acc::DeviceType::None));
2152 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2159 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2166 needCommaBeforeOperands =
true;
2169 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2176 int32_t crtOperandsSize = operands.size();
2188 if (parser.parseOperand(operands.emplace_back()) ||
2189 parser.parseColonType(types.emplace_back()))
2195 seg.push_back(operands.size() - crtOperandsSize);
2205 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2206 parser.
getContext(), mlir::acc::DeviceType::None));
2213 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2214 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2216 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2224 if (attrs->size() != 1)
2226 if (
auto deviceTypeAttr =
2227 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2228 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2234 std::optional<mlir::ArrayAttr> deviceTypes,
2235 std::optional<mlir::DenseI32ArrayAttr> segments,
2236 std::optional<mlir::ArrayAttr> hasDevNum,
2237 std::optional<mlir::ArrayAttr> keywordOnly) {
2250 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2252 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2253 if (boolAttr && boolAttr.getValue())
2255 llvm::interleaveComma(
2256 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2257 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2274 if (parser.parseOperand(operands.emplace_back()) ||
2275 parser.parseColonType(types.emplace_back()))
2277 if (succeeded(parser.parseOptionalLSquare())) {
2278 if (parser.parseAttribute(attributes.emplace_back()) ||
2279 parser.parseRSquare())
2282 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2283 parser.getContext(), mlir::acc::DeviceType::None));
2290 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2297 std::optional<mlir::ArrayAttr> deviceTypes) {
2300 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2301 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2310 mlir::ArrayAttr &keywordOnlyDeviceType) {
2313 bool needCommaBeforeOperands =
false;
2317 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2318 parser.
getContext(), mlir::acc::DeviceType::None));
2319 keywordOnlyDeviceType =
2320 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2328 if (parser.parseAttribute(
2329 keywordOnlyDeviceTypeAttributes.emplace_back()))
2336 needCommaBeforeOperands =
true;
2339 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2344 if (parser.parseOperand(operands.emplace_back()) ||
2345 parser.parseColonType(types.emplace_back()))
2347 if (succeeded(parser.parseOptionalLSquare())) {
2348 if (parser.parseAttribute(attributes.emplace_back()) ||
2349 parser.parseRSquare())
2352 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2353 parser.getContext(), mlir::acc::DeviceType::None));
2359 if (
failed(parser.parseRParen()))
2364 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2371 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2373 if (operands.begin() == operands.end() &&
2389 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2390 mlir::Type &operandType, mlir::UnitAttr &attr) {
2393 attr = mlir::UnitAttr::get(parser.
getContext());
2403 if (failed(parser.
parseType(operandType)))
2413 std::optional<mlir::Value> operand,
2415 mlir::UnitAttr attr) {
2432 attr = mlir::UnitAttr::get(parser.
getContext());
2437 if (parser.parseOperand(operands.emplace_back()))
2445 if (parser.parseType(types.emplace_back()))
2460 mlir::UnitAttr attr) {
2465 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2467 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2473 mlir::acc::CombinedConstructsTypeAttr &attr) {
2475 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2476 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2478 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2479 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2481 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2482 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2485 "expected compute construct name");
2493 mlir::acc::CombinedConstructsTypeAttr attr) {
2495 switch (attr.getValue()) {
2496 case mlir::acc::CombinedConstructsType::KernelsLoop:
2499 case mlir::acc::CombinedConstructsType::ParallelLoop:
2502 case mlir::acc::CombinedConstructsType::SerialLoop:
2513unsigned SerialOp::getNumDataOperands() {
2514 return getReductionOperands().size() + getPrivateOperands().size() +
2515 getFirstprivateOperands().size() + getDataClauseOperands().size();
2518Value SerialOp::getDataOperand(
unsigned i) {
2520 numOptional += getIfCond() ? 1 : 0;
2521 numOptional += getSelfCond() ? 1 : 0;
2522 return getOperand(getWaitOperands().size() + numOptional + i);
2525bool acc::SerialOp::hasAsyncOnly() {
2526 return hasAsyncOnly(mlir::acc::DeviceType::None);
2529bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2534 return getAsyncValue(mlir::acc::DeviceType::None);
2537mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2542bool acc::SerialOp::hasWaitOnly() {
2543 return hasWaitOnly(mlir::acc::DeviceType::None);
2546bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2551 return getWaitValues(mlir::acc::DeviceType::None);
2555SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2557 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2558 getHasWaitDevnum(), deviceType);
2562 return getWaitDevnum(mlir::acc::DeviceType::None);
2565mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2567 getWaitOperandsSegments(), getHasWaitDevnum(),
2571LogicalResult acc::SerialOp::verify() {
2573 mlir::acc::PrivateRecipeOp>(
2574 *
this, getPrivateOperands(),
"private")))
2577 mlir::acc::FirstprivateRecipeOp>(
2578 *
this, getFirstprivateOperands(),
"firstprivate")))
2581 mlir::acc::ReductionRecipeOp>(
2582 *
this, getReductionOperands(),
"reduction")))
2586 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2587 getWaitOperandsDeviceTypeAttr(),
"wait")))
2591 getAsyncOperandsDeviceTypeAttr(),
2601void acc::SerialOp::addAsyncOnly(
2603 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2604 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2607void acc::SerialOp::addAsyncOperand(
2610 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2611 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2612 getAsyncOperandsMutable()));
2615void acc::SerialOp::addWaitOnly(
2617 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2618 effectiveDeviceTypes));
2620void acc::SerialOp::addWaitOperands(
2625 if (getWaitOperandsSegments())
2626 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2628 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2629 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2630 getWaitOperandsMutable(), segments));
2631 setWaitOperandsSegments(segments);
2634 if (getHasWaitDevnumAttr())
2635 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2638 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2640 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2643void acc::SerialOp::addPrivatization(
MLIRContext *context,
2644 mlir::acc::PrivateOp op,
2645 mlir::acc::PrivateRecipeOp recipe) {
2646 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2647 getPrivateOperandsMutable().append(op.getResult());
2650void acc::SerialOp::addFirstPrivatization(
2651 MLIRContext *context, mlir::acc::FirstprivateOp op,
2652 mlir::acc::FirstprivateRecipeOp recipe) {
2653 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2654 getFirstprivateOperandsMutable().append(op.getResult());
2657void acc::SerialOp::addReduction(
MLIRContext *context,
2658 mlir::acc::ReductionOp op,
2659 mlir::acc::ReductionRecipeOp recipe) {
2660 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2661 getReductionOperandsMutable().append(op.getResult());
2668unsigned KernelsOp::getNumDataOperands() {
2669 return getDataClauseOperands().size();
2672Value KernelsOp::getDataOperand(
unsigned i) {
2674 numOptional += getWaitOperands().size();
2675 numOptional += getNumGangs().size();
2676 numOptional += getNumWorkers().size();
2677 numOptional += getVectorLength().size();
2678 numOptional += getIfCond() ? 1 : 0;
2679 numOptional += getSelfCond() ? 1 : 0;
2680 return getOperand(numOptional + i);
2683bool acc::KernelsOp::hasAsyncOnly() {
2684 return hasAsyncOnly(mlir::acc::DeviceType::None);
2687bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2692 return getAsyncValue(mlir::acc::DeviceType::None);
2695mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2701 return getNumWorkersValue(mlir::acc::DeviceType::None);
2705acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2710mlir::Value acc::KernelsOp::getVectorLengthValue() {
2711 return getVectorLengthValue(mlir::acc::DeviceType::None);
2715acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2717 getVectorLength(), deviceType);
2721 return getNumGangsValues(mlir::acc::DeviceType::None);
2725KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2727 getNumGangsSegments(), deviceType);
2730bool acc::KernelsOp::hasWaitOnly() {
2731 return hasWaitOnly(mlir::acc::DeviceType::None);
2734bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2739 return getWaitValues(mlir::acc::DeviceType::None);
2743KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2745 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2746 getHasWaitDevnum(), deviceType);
2750 return getWaitDevnum(mlir::acc::DeviceType::None);
2753mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2755 getWaitOperandsSegments(), getHasWaitDevnum(),
2759LogicalResult acc::KernelsOp::verify() {
2761 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2762 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2766 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2767 getWaitOperandsDeviceTypeAttr(),
"wait")))
2771 getNumWorkersDeviceTypeAttr(),
2776 getVectorLengthDeviceTypeAttr(),
2781 getAsyncOperandsDeviceTypeAttr(),
2791void acc::KernelsOp::addPrivatization(
MLIRContext *context,
2792 mlir::acc::PrivateOp op,
2793 mlir::acc::PrivateRecipeOp recipe) {
2794 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2795 getPrivateOperandsMutable().append(op.getResult());
2798void acc::KernelsOp::addFirstPrivatization(
2799 MLIRContext *context, mlir::acc::FirstprivateOp op,
2800 mlir::acc::FirstprivateRecipeOp recipe) {
2801 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2802 getFirstprivateOperandsMutable().append(op.getResult());
2805void acc::KernelsOp::addReduction(
MLIRContext *context,
2806 mlir::acc::ReductionOp op,
2807 mlir::acc::ReductionRecipeOp recipe) {
2808 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2809 getReductionOperandsMutable().append(op.getResult());
2812void acc::KernelsOp::addNumWorkersOperand(
2815 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2816 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2817 getNumWorkersMutable()));
2820void acc::KernelsOp::addVectorLengthOperand(
2823 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2824 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2825 getVectorLengthMutable()));
2827void acc::KernelsOp::addAsyncOnly(
2829 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2830 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2833void acc::KernelsOp::addAsyncOperand(
2836 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2837 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2838 getAsyncOperandsMutable()));
2841void acc::KernelsOp::addNumGangsOperands(
2845 if (getNumGangsSegmentsAttr())
2846 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2848 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2850 getNumGangsMutable(), segments));
2852 setNumGangsSegments(segments);
2855void acc::KernelsOp::addWaitOnly(
2857 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2858 effectiveDeviceTypes));
2860void acc::KernelsOp::addWaitOperands(
2865 if (getWaitOperandsSegments())
2866 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2868 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2869 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2870 getWaitOperandsMutable(), segments));
2871 setWaitOperandsSegments(segments);
2874 if (getHasWaitDevnumAttr())
2875 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2878 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2880 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2887LogicalResult acc::HostDataOp::verify() {
2888 if (getDataClauseOperands().empty())
2889 return emitError(
"at least one operand must appear on the host_data "
2892 for (
mlir::Value operand : getDataClauseOperands())
2893 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2894 return emitError(
"expect data entry operation as defining op");
2900 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2907void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2909 results.
add<RemoveEmptyKernelEnvironment>(context);
2921 bool &needCommaBetweenValues,
bool &newValue) {
2928 attributes.push_back(gangArgType);
2929 needCommaBetweenValues =
true;
2940 mlir::ArrayAttr &gangOnlyDeviceType) {
2945 bool needCommaBetweenValues =
false;
2946 bool needCommaBeforeOperands =
false;
2950 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2951 parser.
getContext(), mlir::acc::DeviceType::None));
2952 gangOnlyDeviceType =
2953 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
2961 if (parser.parseAttribute(
2962 gangOnlyDeviceTypeAttributes.emplace_back()))
2969 needCommaBeforeOperands =
true;
2972 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2973 mlir::acc::GangArgType::Num);
2974 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2975 mlir::acc::GangArgType::Dim);
2976 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2977 parser.
getContext(), mlir::acc::GangArgType::Static);
2980 if (needCommaBeforeOperands) {
2981 needCommaBeforeOperands =
false;
2988 int32_t crtOperandsSize = gangOperands.size();
2990 bool newValue =
false;
2991 bool needValue =
false;
2992 if (needCommaBetweenValues) {
3000 gangOperands, gangOperandsType,
3001 gangArgTypeAttributes, argNum,
3002 needCommaBetweenValues, newValue)))
3005 gangOperands, gangOperandsType,
3006 gangArgTypeAttributes, argDim,
3007 needCommaBetweenValues, newValue)))
3009 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3010 gangOperands, gangOperandsType,
3011 gangArgTypeAttributes, argStatic,
3012 needCommaBetweenValues, newValue)))
3015 if (!newValue && needValue) {
3017 "new value expected after comma");
3025 if (gangOperands.empty())
3028 "expect at least one of num, dim or static values");
3034 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3038 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3039 parser.
getContext(), mlir::acc::DeviceType::None));
3042 seg.push_back(gangOperands.size() - crtOperandsSize);
3050 gangArgTypeAttributes.end());
3051 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3052 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3055 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3056 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3064 std::optional<mlir::ArrayAttr> gangArgTypes,
3065 std::optional<mlir::ArrayAttr> deviceTypes,
3066 std::optional<mlir::DenseI32ArrayAttr> segments,
3067 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3069 if (operands.begin() == operands.end() &&
3084 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3086 llvm::interleaveComma(
3087 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3088 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3089 (*gangArgTypes)[opIdx]);
3090 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3091 p << LoopOp::getGangNumKeyword();
3092 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3093 p << LoopOp::getGangDimKeyword();
3094 else if (gangArgTypeAttr.getValue() ==
3095 mlir::acc::GangArgType::Static)
3096 p << LoopOp::getGangStaticKeyword();
3097 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3108 std::optional<mlir::ArrayAttr> segments,
3109 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3112 for (
auto attr : *segments) {
3113 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3114 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3122static std::optional<mlir::acc::DeviceType>
3124 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3126 return std::nullopt;
3127 for (
auto attr : deviceTypes) {
3128 auto deviceTypeAttr =
3129 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3130 if (!deviceTypeAttr)
3131 return mlir::acc::DeviceType::None;
3132 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3133 return deviceTypeAttr.getValue();
3135 return std::nullopt;
3138LogicalResult acc::LoopOp::verify() {
3139 if (getUpperbound().size() != getStep().size())
3140 return emitError() <<
"number of upperbounds expected to be the same as "
3143 if (getUpperbound().size() != getLowerbound().size())
3144 return emitError() <<
"number of upperbounds expected to be the same as "
3145 "number of lowerbounds";
3147 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3148 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3149 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3150 <<
" as upperbound size";
3153 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3154 return emitOpError() <<
"collapse device_type attr must be define when"
3155 <<
" collapse attr is present";
3157 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3158 getCollapseAttr().getValue().size() !=
3159 getCollapseDeviceTypeAttr().getValue().size())
3160 return emitOpError() <<
"collapse attribute count must match collapse"
3161 <<
" device_type count";
3162 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3164 << acc::stringifyDeviceType(*duplicateDeviceType)
3165 <<
"` found in collapseDeviceType attribute";
3168 if (!getGangOperands().empty()) {
3169 if (!getGangOperandsArgType())
3170 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3171 <<
" when gang operands are present";
3173 if (getGangOperands().size() !=
3174 getGangOperandsArgTypeAttr().getValue().size())
3175 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3176 <<
" gangOperands count";
3178 if (getGangAttr()) {
3181 << acc::stringifyDeviceType(*duplicateDeviceType)
3182 <<
"` found in gang attribute";
3186 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3187 getGangOperandsDeviceTypeAttr(),
"gang")))
3193 << acc::stringifyDeviceType(*duplicateDeviceType)
3194 <<
"` found in worker attribute";
3195 if (
auto duplicateDeviceType =
3198 << acc::stringifyDeviceType(*duplicateDeviceType)
3199 <<
"` found in workerNumOperandsDeviceType attribute";
3201 getWorkerNumOperandsDeviceTypeAttr(),
3208 << acc::stringifyDeviceType(*duplicateDeviceType)
3209 <<
"` found in vector attribute";
3210 if (
auto duplicateDeviceType =
3213 << acc::stringifyDeviceType(*duplicateDeviceType)
3214 <<
"` found in vectorOperandsDeviceType attribute";
3216 getVectorOperandsDeviceTypeAttr(),
3221 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3222 getTileOperandsDeviceTypeAttr(),
"tile")))
3226 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3230 return emitError() <<
"only one of auto, independent, seq can be present "
3236 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3237 return attr.getValue() == mlir::acc::DeviceType::None;
3239 bool hasDefaultSeq =
3241 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3244 bool hasDefaultIndependent =
3245 getIndependentAttr()
3247 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3250 bool hasDefaultAuto =
3252 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3255 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3257 <<
"at least one of auto, independent, seq must be present";
3262 for (
auto attr : getSeqAttr()) {
3263 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3264 if (hasVector(deviceTypeAttr.getValue()) ||
3265 getVectorValue(deviceTypeAttr.getValue()) ||
3266 hasWorker(deviceTypeAttr.getValue()) ||
3267 getWorkerValue(deviceTypeAttr.getValue()) ||
3268 hasGang(deviceTypeAttr.getValue()) ||
3269 getGangValue(mlir::acc::GangArgType::Num,
3270 deviceTypeAttr.getValue()) ||
3271 getGangValue(mlir::acc::GangArgType::Dim,
3272 deviceTypeAttr.getValue()) ||
3273 getGangValue(mlir::acc::GangArgType::Static,
3274 deviceTypeAttr.getValue()))
3275 return emitError() <<
"gang, worker or vector cannot appear with seq";
3280 mlir::acc::PrivateRecipeOp>(
3281 *
this, getPrivateOperands(),
"private")))
3285 mlir::acc::FirstprivateRecipeOp>(
3286 *
this, getFirstprivateOperands(),
"firstprivate")))
3290 mlir::acc::ReductionRecipeOp>(
3291 *
this, getReductionOperands(),
"reduction")))
3294 if (getCombined().has_value() &&
3295 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3296 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3297 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3298 return emitError(
"unexpected combined constructs attribute");
3302 if (getRegion().empty())
3303 return emitError(
"expected non-empty body.");
3305 if (getUnstructured()) {
3306 if (!isContainerLike())
3308 "unstructured acc.loop must not have induction variables");
3309 }
else if (isContainerLike()) {
3313 uint64_t collapseCount = getCollapseValue().value_or(1);
3314 if (getCollapseAttr()) {
3315 for (
auto collapseEntry : getCollapseAttr()) {
3316 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3317 if (intAttr.getValue().getZExtValue() > collapseCount)
3318 collapseCount = intAttr.getValue().getZExtValue();
3326 bool foundSibling =
false;
3328 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3330 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3332 foundSibling =
true;
3337 expectedParent = op;
3340 if (collapseCount == 0)
3346 return emitError(
"found sibling loops inside container-like acc.loop");
3347 if (collapseCount != 0)
3348 return emitError(
"failed to find enough loop-like operations inside "
3349 "container-like acc.loop");
3355unsigned LoopOp::getNumDataOperands() {
3356 return getReductionOperands().size() + getPrivateOperands().size() +
3357 getFirstprivateOperands().size();
3360Value LoopOp::getDataOperand(
unsigned i) {
3361 unsigned numOptional =
3362 getLowerbound().size() + getUpperbound().size() + getStep().size();
3363 numOptional += getGangOperands().size();
3364 numOptional += getVectorOperands().size();
3365 numOptional += getWorkerNumOperands().size();
3366 numOptional += getTileOperands().size();
3367 numOptional += getCacheOperands().size();
3368 return getOperand(numOptional + i);
3371bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3373bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3377bool LoopOp::hasIndependent() {
3378 return hasIndependent(mlir::acc::DeviceType::None);
3381bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3385bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3387bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3392 return getVectorValue(mlir::acc::DeviceType::None);
3395mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3397 getVectorOperands(), deviceType);
3400bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3402bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3407 return getWorkerValue(mlir::acc::DeviceType::None);
3410mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3412 getWorkerNumOperands(), deviceType);
3415bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3417bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3422 return getTileValues(mlir::acc::DeviceType::None);
3426LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3428 getTileOperandsSegments(), deviceType);
3431std::optional<int64_t> LoopOp::getCollapseValue() {
3432 return getCollapseValue(mlir::acc::DeviceType::None);
3435std::optional<int64_t>
3436LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3437 if (!getCollapseAttr())
3438 return std::nullopt;
3439 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3441 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3442 return intAttr.getValue().getZExtValue();
3444 return std::nullopt;
3447mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3448 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3451mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3452 mlir::acc::DeviceType deviceType) {
3453 if (getGangOperands().empty())
3455 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3456 int32_t nbOperandsBefore = 0;
3457 for (
unsigned i = 0; i < *pos; ++i)
3458 nbOperandsBefore += (*getGangOperandsSegments())[i];
3461 .drop_front(nbOperandsBefore)
3462 .take_front((*getGangOperandsSegments())[*pos]);
3464 int32_t argTypeIdx = nbOperandsBefore;
3465 for (
auto value : values) {
3466 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3467 (*getGangOperandsArgType())[argTypeIdx]);
3468 if (gangArgTypeAttr.getValue() == gangArgType)
3476bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3478bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3483 return {&getRegion()};
3527 if (!regionArgs.empty()) {
3528 p << acc::LoopOp::getControlKeyword() <<
"(";
3529 llvm::interleaveComma(regionArgs, p,
3531 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3532 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3533 <<
" : " << stepType <<
") ";
3540 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3541 effectiveDeviceTypes));
3544void acc::LoopOp::addIndependent(
3546 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3547 context, getIndependentAttr(), effectiveDeviceTypes));
3552 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3553 effectiveDeviceTypes));
3556void acc::LoopOp::setCollapseForDeviceTypes(
3558 llvm::APInt value) {
3562 assert((getCollapseAttr() ==
nullptr) ==
3563 (getCollapseDeviceTypeAttr() ==
nullptr));
3564 assert(value.getBitWidth() == 64);
3566 if (getCollapseAttr()) {
3567 for (
const auto &existing :
3568 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3569 newValues.push_back(std::get<0>(existing));
3570 newDeviceTypes.push_back(std::get<1>(existing));
3574 if (effectiveDeviceTypes.empty()) {
3577 newValues.push_back(
3578 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3579 newDeviceTypes.push_back(
3580 acc::DeviceTypeAttr::get(context, DeviceType::None));
3582 for (DeviceType dt : effectiveDeviceTypes) {
3583 newValues.push_back(
3584 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3585 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3589 setCollapseAttr(ArrayAttr::get(context, newValues));
3590 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3593void acc::LoopOp::setTileForDeviceTypes(
3597 if (getTileOperandsSegments())
3598 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3600 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3601 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3602 getTileOperandsMutable(), segments));
3604 setTileOperandsSegments(segments);
3607void acc::LoopOp::addVectorOperand(
3610 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3611 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3612 newValue, getVectorOperandsMutable()));
3615void acc::LoopOp::addEmptyVector(
3617 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3618 effectiveDeviceTypes));
3621void acc::LoopOp::addWorkerNumOperand(
3624 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3625 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3626 newValue, getWorkerNumOperandsMutable()));
3629void acc::LoopOp::addEmptyWorker(
3631 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3632 effectiveDeviceTypes));
3635void acc::LoopOp::addEmptyGang(
3637 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3638 effectiveDeviceTypes));
3641bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3642 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3643 return attr.getValue() == dt;
3645 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3646 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3649 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3651 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3653 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3659bool acc::LoopOp::hasDefaultGangWorkerVector() {
3660 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3661 hasGang() || getGangValue(GangArgType::Num) ||
3662 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3666acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3667 if (hasSeq(deviceType))
3668 return LoopParMode::loop_seq;
3669 if (hasAuto(deviceType))
3670 return LoopParMode::loop_auto;
3671 if (hasIndependent(deviceType))
3672 return LoopParMode::loop_independent;
3674 return LoopParMode::loop_seq;
3676 return LoopParMode::loop_auto;
3677 assert(hasIndependent() &&
3678 "loop must have default auto, seq, or independent");
3679 return LoopParMode::loop_independent;
3682void acc::LoopOp::addGangOperands(
3687 getGangOperandsSegments())
3688 llvm::copy(*existingSegments, std::back_inserter(segments));
3690 unsigned beforeCount = segments.size();
3692 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3693 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3694 getGangOperandsMutable(), segments));
3696 setGangOperandsSegments(segments);
3703 unsigned numAdded = segments.size() - beforeCount;
3707 if (getGangOperandsArgTypeAttr())
3708 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3710 for (
auto i : llvm::index_range(0u, numAdded)) {
3711 llvm::transform(argTypes, std::back_inserter(gangTypes),
3712 [=](mlir::acc::GangArgType gangTy) {
3713 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3718 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3722void acc::LoopOp::addPrivatization(
MLIRContext *context,
3723 mlir::acc::PrivateOp op,
3724 mlir::acc::PrivateRecipeOp recipe) {
3725 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3726 getPrivateOperandsMutable().append(op.getResult());
3729void acc::LoopOp::addFirstPrivatization(
3730 MLIRContext *context, mlir::acc::FirstprivateOp op,
3731 mlir::acc::FirstprivateRecipeOp recipe) {
3732 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3733 getFirstprivateOperandsMutable().append(op.getResult());
3736void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3737 mlir::acc::ReductionRecipeOp recipe) {
3738 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3739 getReductionOperandsMutable().append(op.getResult());
3746LogicalResult acc::DataOp::verify() {
3751 return emitError(
"at least one operand or the default attribute "
3752 "must appear on the data operation");
3754 for (
mlir::Value operand : getDataClauseOperands())
3755 if (isa<BlockArgument>(operand) ||
3756 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3757 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3758 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3759 operand.getDefiningOp()))
3760 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3769unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3771Value DataOp::getDataOperand(
unsigned i) {
3772 unsigned numOptional = getIfCond() ? 1 : 0;
3774 numOptional += getWaitOperands().size();
3775 return getOperand(numOptional + i);
3778bool acc::DataOp::hasAsyncOnly() {
3779 return hasAsyncOnly(mlir::acc::DeviceType::None);
3782bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3787 return getAsyncValue(mlir::acc::DeviceType::None);
3790mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3795bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3797bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3802 return getWaitValues(mlir::acc::DeviceType::None);
3806DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3808 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3809 getHasWaitDevnum(), deviceType);
3813 return getWaitDevnum(mlir::acc::DeviceType::None);
3816mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3818 getWaitOperandsSegments(), getHasWaitDevnum(),
3822void acc::DataOp::addAsyncOnly(
3824 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3825 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3828void acc::DataOp::addAsyncOperand(
3831 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3832 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3833 getAsyncOperandsMutable()));
3836void acc::DataOp::addWaitOnly(
MLIRContext *context,
3838 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3839 effectiveDeviceTypes));
3842void acc::DataOp::addWaitOperands(
3847 if (getWaitOperandsSegments())
3848 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3850 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3851 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3852 getWaitOperandsMutable(), segments));
3853 setWaitOperandsSegments(segments);
3856 if (getHasWaitDevnumAttr())
3857 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3860 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3862 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3869LogicalResult acc::ExitDataOp::verify() {
3873 if (getDataClauseOperands().empty())
3874 return emitError(
"at least one operand must be present in dataOperands on "
3875 "the exit data operation");
3879 if (getAsyncOperand() && getAsync())
3880 return emitError(
"async attribute cannot appear with asyncOperand");
3884 if (!getWaitOperands().empty() && getWait())
3885 return emitError(
"wait attribute cannot appear with waitOperands");
3887 if (getWaitDevnum() && getWaitOperands().empty())
3888 return emitError(
"wait_devnum cannot appear without waitOperands");
3893unsigned ExitDataOp::getNumDataOperands() {
3894 return getDataClauseOperands().size();
3897Value ExitDataOp::getDataOperand(
unsigned i) {
3898 unsigned numOptional = getIfCond() ? 1 : 0;
3899 numOptional += getAsyncOperand() ? 1 : 0;
3900 numOptional += getWaitDevnum() ? 1 : 0;
3901 return getOperand(getWaitOperands().size() + numOptional + i);
3906 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3909void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3911 assert(effectiveDeviceTypes.empty());
3912 assert(!getAsyncAttr());
3913 assert(!getAsyncOperand());
3915 setAsyncAttr(mlir::UnitAttr::get(context));
3918void ExitDataOp::addAsyncOperand(
3921 assert(effectiveDeviceTypes.empty());
3922 assert(!getAsyncAttr());
3923 assert(!getAsyncOperand());
3925 getAsyncOperandMutable().append(newValue);
3930 assert(effectiveDeviceTypes.empty());
3931 assert(!getWaitAttr());
3932 assert(getWaitOperands().empty());
3933 assert(!getWaitDevnum());
3935 setWaitAttr(mlir::UnitAttr::get(context));
3938void ExitDataOp::addWaitOperands(
3941 assert(effectiveDeviceTypes.empty());
3942 assert(!getWaitAttr());
3943 assert(getWaitOperands().empty());
3944 assert(!getWaitDevnum());
3949 getWaitDevnumMutable().append(newValues.front());
3950 newValues = newValues.drop_front();
3953 getWaitOperandsMutable().append(newValues);
3960LogicalResult acc::EnterDataOp::verify() {
3964 if (getDataClauseOperands().empty())
3965 return emitError(
"at least one operand must be present in dataOperands on "
3966 "the enter data operation");
3970 if (getAsyncOperand() && getAsync())
3971 return emitError(
"async attribute cannot appear with asyncOperand");
3975 if (!getWaitOperands().empty() && getWait())
3976 return emitError(
"wait attribute cannot appear with waitOperands");
3978 if (getWaitDevnum() && getWaitOperands().empty())
3979 return emitError(
"wait_devnum cannot appear without waitOperands");
3981 for (
mlir::Value operand : getDataClauseOperands())
3982 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3983 operand.getDefiningOp()))
3984 return emitError(
"expect data entry operation as defining op");
3989unsigned EnterDataOp::getNumDataOperands() {
3990 return getDataClauseOperands().size();
3993Value EnterDataOp::getDataOperand(
unsigned i) {
3994 unsigned numOptional = getIfCond() ? 1 : 0;
3995 numOptional += getAsyncOperand() ? 1 : 0;
3996 numOptional += getWaitDevnum() ? 1 : 0;
3997 return getOperand(getWaitOperands().size() + numOptional + i);
4002 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4005void EnterDataOp::addAsyncOnly(
4007 assert(effectiveDeviceTypes.empty());
4008 assert(!getAsyncAttr());
4009 assert(!getAsyncOperand());
4011 setAsyncAttr(mlir::UnitAttr::get(context));
4014void EnterDataOp::addAsyncOperand(
4017 assert(effectiveDeviceTypes.empty());
4018 assert(!getAsyncAttr());
4019 assert(!getAsyncOperand());
4021 getAsyncOperandMutable().append(newValue);
4024void EnterDataOp::addWaitOnly(
MLIRContext *context,
4026 assert(effectiveDeviceTypes.empty());
4027 assert(!getWaitAttr());
4028 assert(getWaitOperands().empty());
4029 assert(!getWaitDevnum());
4031 setWaitAttr(mlir::UnitAttr::get(context));
4034void EnterDataOp::addWaitOperands(
4037 assert(effectiveDeviceTypes.empty());
4038 assert(!getWaitAttr());
4039 assert(getWaitOperands().empty());
4040 assert(!getWaitDevnum());
4045 getWaitDevnumMutable().append(newValues.front());
4046 newValues = newValues.drop_front();
4049 getWaitOperandsMutable().append(newValues);
4056LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4062LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4068LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4075 if (
Value writeVal = op.getWriteOpVal()) {
4084LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4086LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4092AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4093 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4095 return dyn_cast<AtomicReadOp>(getSecondOp());
4098AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4099 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4101 return dyn_cast<AtomicWriteOp>(getSecondOp());
4104AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4105 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4107 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4110LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4116template <
typename Op>
4119 bool requireAtLeastOneOperand =
true) {
4120 if (operands.empty() && requireAtLeastOneOperand)
4123 "at least one operand must appear on the declare operation");
4126 if (isa<BlockArgument>(operand) ||
4127 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4128 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4129 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4130 operand.getDefiningOp()))
4132 "expect valid declare data entry operation or acc.getdeviceptr "
4136 assert(var &&
"declare operands can only be data entry operations which "
4139 std::optional<mlir::acc::DataClause> dataClauseOptional{
4141 assert(dataClauseOptional.has_value() &&
4142 "declare operands can only be data entry operations which must have "
4144 (
void)dataClauseOptional;
4150LogicalResult acc::DeclareEnterOp::verify() {
4158LogicalResult acc::DeclareExitOp::verify() {
4169LogicalResult acc::DeclareOp::verify() {
4178 acc::DeviceType dtype) {
4179 unsigned parallelism = 0;
4180 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4181 parallelism += op.hasWorker(dtype) ? 1 : 0;
4182 parallelism += op.hasVector(dtype) ? 1 : 0;
4183 parallelism += op.hasSeq(dtype) ? 1 : 0;
4187LogicalResult acc::RoutineOp::verify() {
4188 unsigned baseParallelism =
4191 if (baseParallelism > 1)
4192 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4193 "be present at the same time";
4195 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4197 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4198 if (dtype == acc::DeviceType::None)
4202 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4203 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4204 "be present at the same time for device_type `"
4205 << acc::stringifyDeviceType(dtype) <<
"`";
4212 mlir::ArrayAttr &bindIdName,
4213 mlir::ArrayAttr &bindStrName,
4214 mlir::ArrayAttr &deviceIdTypes,
4215 mlir::ArrayAttr &deviceStrTypes) {
4222 mlir::Attribute newAttr;
4223 bool isSymbolRefAttr;
4224 auto parseResult = parser.parseAttribute(newAttr);
4225 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4226 bindIdNameAttrs.push_back(symbolRefAttr);
4227 isSymbolRefAttr = true;
4228 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4229 bindStrNameAttrs.push_back(stringAttr);
4230 isSymbolRefAttr =
false;
4235 if (isSymbolRefAttr) {
4236 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4237 parser.getContext(), mlir::acc::DeviceType::None));
4239 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4240 parser.getContext(), mlir::acc::DeviceType::None));
4243 if (isSymbolRefAttr) {
4244 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4245 parser.parseRSquare())
4248 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4249 parser.parseRSquare())
4257 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4258 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4259 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4260 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4266 std::optional<mlir::ArrayAttr> bindIdName,
4267 std::optional<mlir::ArrayAttr> bindStrName,
4268 std::optional<mlir::ArrayAttr> deviceIdTypes,
4269 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4276 allBindNames.append(bindIdName->begin(), bindIdName->end());
4277 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4282 allBindNames.append(bindStrName->begin(), bindStrName->end());
4283 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4287 if (!allBindNames.empty())
4288 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4289 [&](
const auto &pair) {
4290 p << std::get<0>(pair);
4296 mlir::ArrayAttr &gang,
4297 mlir::ArrayAttr &gangDim,
4298 mlir::ArrayAttr &gangDimDeviceTypes) {
4301 gangDimDeviceTypeAttrs;
4302 bool needCommaBeforeOperands =
false;
4306 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4307 parser.
getContext(), mlir::acc::DeviceType::None));
4308 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4315 if (parser.parseAttribute(gangAttrs.emplace_back()))
4322 needCommaBeforeOperands =
true;
4325 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4329 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4330 parser.parseColon() ||
4331 parser.parseAttribute(gangDimAttrs.emplace_back()))
4333 if (succeeded(parser.parseOptionalLSquare())) {
4334 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4335 parser.parseRSquare())
4338 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4339 parser.getContext(), mlir::acc::DeviceType::None));
4345 if (
failed(parser.parseRParen()))
4348 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4349 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4350 gangDimDeviceTypes =
4351 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4357 std::optional<mlir::ArrayAttr> gang,
4358 std::optional<mlir::ArrayAttr> gangDim,
4359 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4362 gang->size() == 1) {
4363 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4364 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4376 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4377 [&](
const auto &pair) {
4378 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4379 p << std::get<0>(pair);
4387 mlir::ArrayAttr &deviceTypes) {
4391 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4392 parser.
getContext(), mlir::acc::DeviceType::None));
4393 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4400 if (parser.parseAttribute(attributes.emplace_back()))
4408 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4414 std::optional<mlir::ArrayAttr> deviceTypes) {
4417 auto deviceTypeAttr =
4418 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4419 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4428 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4434bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4436bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4440bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4442bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4446bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4448bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4452std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4453RoutineOp::getBindNameValue() {
4454 return getBindNameValue(mlir::acc::DeviceType::None);
4457std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4458RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4461 return std::nullopt;
4464 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4465 auto attr = (*getBindIdName())[*pos];
4466 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4467 assert(symbolRefAttr &&
"expected SymbolRef");
4468 return symbolRefAttr;
4471 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4472 auto attr = (*getBindStrName())[*pos];
4473 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4474 assert(stringAttr &&
"expected String");
4478 return std::nullopt;
4481bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4483bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4487std::optional<int64_t> RoutineOp::getGangDimValue() {
4488 return getGangDimValue(mlir::acc::DeviceType::None);
4491std::optional<int64_t>
4492RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4494 return std::nullopt;
4495 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4496 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4497 return intAttr.getInt();
4499 return std::nullopt;
4504 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4505 effectiveDeviceTypes));
4510 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4511 effectiveDeviceTypes));
4516 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4517 effectiveDeviceTypes));
4522 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4523 effectiveDeviceTypes));
4532 if (getGangDimAttr())
4533 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4534 if (getGangDimDeviceTypeAttr())
4535 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4537 assert(dimValues.size() == deviceTypes.size());
4539 if (effectiveDeviceTypes.empty()) {
4540 dimValues.push_back(
4541 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4542 deviceTypes.push_back(
4543 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4545 for (DeviceType dt : effectiveDeviceTypes) {
4546 dimValues.push_back(
4547 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4548 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4551 assert(dimValues.size() == deviceTypes.size());
4553 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4554 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4557void RoutineOp::addBindStrName(
MLIRContext *context,
4559 mlir::StringAttr val) {
4560 unsigned before = getBindStrNameDeviceTypeAttr()
4561 ? getBindStrNameDeviceTypeAttr().size()
4564 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4565 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4566 unsigned after = getBindStrNameDeviceTypeAttr().size();
4569 if (getBindStrNameAttr())
4570 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4571 for (
unsigned i = 0; i < after - before; ++i)
4572 vals.push_back(val);
4574 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4577void RoutineOp::addBindIDName(
MLIRContext *context,
4579 mlir::SymbolRefAttr val) {
4581 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4583 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4584 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4585 unsigned after = getBindIdNameDeviceTypeAttr().size();
4588 if (getBindIdNameAttr())
4589 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4590 for (
unsigned i = 0; i < after - before; ++i)
4591 vals.push_back(val);
4593 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4600LogicalResult acc::InitOp::verify() {
4604 return emitOpError(
"cannot be nested in a compute operation");
4608void acc::InitOp::addDeviceType(
MLIRContext *context,
4609 mlir::acc::DeviceType deviceType) {
4611 if (getDeviceTypesAttr())
4612 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4614 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4615 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4622LogicalResult acc::ShutdownOp::verify() {
4626 return emitOpError(
"cannot be nested in a compute operation");
4630void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4631 mlir::acc::DeviceType deviceType) {
4633 if (getDeviceTypesAttr())
4634 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4636 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4637 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4644LogicalResult acc::SetOp::verify() {
4648 return emitOpError(
"cannot be nested in a compute operation");
4649 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4650 return emitOpError(
"at least one default_async, device_num, or device_type "
4651 "operand must appear");
4659LogicalResult acc::UpdateOp::verify() {
4661 if (getDataClauseOperands().empty())
4662 return emitError(
"at least one value must be present in dataOperands");
4665 getAsyncOperandsDeviceTypeAttr(),
4670 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4671 getWaitOperandsDeviceTypeAttr(),
"wait")))
4677 for (
mlir::Value operand : getDataClauseOperands())
4678 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4679 operand.getDefiningOp()))
4680 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4686unsigned UpdateOp::getNumDataOperands() {
4687 return getDataClauseOperands().size();
4690Value UpdateOp::getDataOperand(
unsigned i) {
4692 numOptional += getIfCond() ? 1 : 0;
4693 return getOperand(getWaitOperands().size() + numOptional + i);
4698 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4701bool UpdateOp::hasAsyncOnly() {
4702 return hasAsyncOnly(mlir::acc::DeviceType::None);
4705bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4710 return getAsyncValue(mlir::acc::DeviceType::None);
4713mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4723bool UpdateOp::hasWaitOnly() {
4724 return hasWaitOnly(mlir::acc::DeviceType::None);
4727bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4732 return getWaitValues(mlir::acc::DeviceType::None);
4736UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4738 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4739 getHasWaitDevnum(), deviceType);
4743 return getWaitDevnum(mlir::acc::DeviceType::None);
4746mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4748 getWaitOperandsSegments(), getHasWaitDevnum(),
4754 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4755 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4758void UpdateOp::addAsyncOperand(
4761 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4762 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4763 getAsyncOperandsMutable()));
4768 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4769 effectiveDeviceTypes));
4772void UpdateOp::addWaitOperands(
4777 if (getWaitOperandsSegments())
4778 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4780 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4781 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4782 getWaitOperandsMutable(), segments));
4783 setWaitOperandsSegments(segments);
4786 if (getHasWaitDevnumAttr())
4787 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4790 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4792 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4799LogicalResult acc::WaitOp::verify() {
4802 if (getAsyncOperand() && getAsync())
4803 return emitError(
"async attribute cannot appear with asyncOperand");
4805 if (getWaitDevnum() && getWaitOperands().empty())
4806 return emitError(
"wait_devnum cannot appear without waitOperands");
4811#define GET_OP_CLASSES
4812#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4814#define GET_ATTRDEF_CLASSES
4815#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4817#define GET_TYPEDEF_CLASSES
4818#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4829 .Case<ACC_DATA_ENTRY_OPS>(
4830 [&](
auto entry) {
return entry.getVarPtr(); })
4831 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4832 [&](
auto exit) {
return exit.getVarPtr(); })
4850 [&](
auto entry) {
return entry.getVarType(); })
4851 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4852 [&](
auto exit) {
return exit.getVarType(); })
4862 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4863 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4873 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4882 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4892 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4894 dataClause.getBounds().begin(), dataClause.getBounds().end());
4906 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4908 dataClause.getAsyncOperands().begin(),
4909 dataClause.getAsyncOperands().end());
4920 return dataClause.getAsyncOperandsDeviceTypeAttr();
4928 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4935 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4942std::optional<mlir::acc::DataClause>
4947 .Case<ACC_DATA_ENTRY_OPS>(
4948 [&](
auto entry) {
return entry.getDataClause(); })
4956 [&](
auto entry) {
return entry.getImplicit(); })
4965 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4967 return dataOperands;
4975 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4977 return dataOperands;
4984 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region ®ion, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > ®ions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.