24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/LogicalResult.h"
32#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
40static bool isScalarLikeType(
Type type) {
48 if (!varName.empty()) {
49 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
55struct MemRefPointerLikeModel
56 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 return cast<T>(pointer).getElementType();
61 mlir::acc::VariableTypeCategory
64 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
65 return mappableTy.getTypeCategory(varPtr);
67 auto memrefTy = cast<T>(pointer);
68 if (!memrefTy.hasRank()) {
71 return mlir::acc::VariableTypeCategory::uncategorized;
74 if (memrefTy.getRank() == 0) {
75 if (isScalarLikeType(memrefTy.getElementType())) {
76 return mlir::acc::VariableTypeCategory::scalar;
80 return mlir::acc::VariableTypeCategory::uncategorized;
84 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
85 return mlir::acc::VariableTypeCategory::array;
88 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
89 StringRef varName, Type varType, Value originalVar,
90 bool &needsFree)
const {
91 auto memrefTy = cast<MemRefType>(pointer);
95 if (memrefTy.hasStaticShape()) {
97 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
98 attachVarNameAttr(allocaOp, builder, varName);
99 return allocaOp.getResult();
104 if (originalVar && originalVar.
getType() == memrefTy &&
105 memrefTy.hasRank()) {
106 SmallVector<Value> dynamicSizes;
107 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
108 if (memrefTy.isDynamicDim(i)) {
112 memref::DimOp::create(builder, loc, originalVar, indexValue);
113 dynamicSizes.push_back(dimSize);
120 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
121 attachVarNameAttr(allocOp, builder, varName);
122 return allocOp.getResult();
129 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 Type varType)
const {
134 Value valueToInspect = allocRes ? allocRes : memrefValue;
137 Value currentValue = valueToInspect;
138 Operation *originalAlloc =
nullptr;
142 while (currentValue) {
145 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
146 originalAlloc = definingOp;
151 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
152 currentValue = castOp.getSource();
157 if (
auto reinterpretCastOp =
158 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
159 currentValue = reinterpretCastOp.getSource();
171 if (isa<memref::AllocaOp>(originalAlloc)) {
175 if (isa<memref::AllocOp>(originalAlloc)) {
177 memref::DeallocOp::create(builder, loc, memrefValue);
186 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
190 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
191 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
197 if (destMemref && srcMemref &&
198 destMemref.getType().getElementType() ==
199 srcMemref.getType().getElementType() &&
200 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
201 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
208 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
210 Type valueType)
const {
215 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
219 auto memrefTy = memrefValue.
getType();
222 if (memrefTy.getRank() != 0)
225 return memref::LoadOp::create(builder, loc, memrefValue);
228 bool genStore(Type pointer, OpBuilder &builder, Location loc,
234 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
238 auto memrefTy = memrefValue.getType();
241 if (memrefTy.getRank() != 0)
244 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249struct LLVMPointerPointerLikeModel
250 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
251 LLVM::LLVMPointerType> {
254 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
256 Type valueType)
const {
261 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
264 bool genStore(Type pointer, OpBuilder &builder, Location loc,
266 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
271struct MemrefAddressOfGlobalModel
272 :
public AddressOfGlobalOpInterface::ExternalModel<
273 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
274 SymbolRefAttr getSymbol(Operation *op)
const {
275 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
276 return getGlobalOp.getNameAttr();
280struct MemrefGlobalVariableModel
281 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
283 bool isConstant(Operation *op)
const {
284 auto globalOp = cast<memref::GlobalOp>(op);
285 return globalOp.getConstant();
288 Region *getInitRegion(Operation *op)
const {
298mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
299 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
302 if (existingDeviceTypes)
303 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
305 if (newDeviceTypes.empty())
306 deviceTypes.push_back(
307 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
309 for (DeviceType dt : newDeviceTypes)
310 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
312 return mlir::ArrayAttr::get(context, deviceTypes);
321mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
322 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
327 if (existingDeviceTypes)
328 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
330 if (newDeviceTypes.empty()) {
331 argCollection.
append(arguments);
332 segments.push_back(arguments.size());
333 deviceTypes.push_back(
334 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
337 for (DeviceType dt : newDeviceTypes) {
338 argCollection.
append(arguments);
339 segments.push_back(arguments.size());
340 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
343 return mlir::ArrayAttr::get(context, deviceTypes);
347mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
348 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
352 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
353 newDeviceTypes, arguments,
354 argCollection, segments);
362void OpenACCDialect::initialize() {
365#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
368#define GET_ATTRDEF_LIST
369#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
372#define GET_TYPEDEF_LIST
373#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
379 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
381 UnrankedMemRefType::attachInterface<
382 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
383 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
387 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
389 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
417void ParallelOp::getSuccessorRegions(
429void KernelEnvironmentOp::getSuccessorRegions(
441void HostDataOp::getSuccessorRegions(
452 if (getUnstructured()) {
471 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
475 mlir::acc::DeviceType deviceType) {
479 for (
auto attr : *arrayAttr) {
480 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
481 if (deviceTypeAttr.getValue() == deviceType)
489 std::optional<mlir::ArrayAttr> deviceTypes) {
494 llvm::interleaveComma(*deviceTypes, p,
500 mlir::acc::DeviceType deviceType) {
501 unsigned segmentIdx = 0;
502 for (
auto attr : segments) {
503 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
504 if (deviceTypeAttr.getValue() == deviceType)
505 return std::make_optional(segmentIdx);
515 mlir::acc::DeviceType deviceType) {
517 return range.take_front(0);
518 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
519 int32_t nbOperandsBefore = 0;
520 for (
unsigned i = 0; i < *pos; ++i)
521 nbOperandsBefore += (*segments)[i];
522 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
524 return range.take_front(0);
531 std::optional<mlir::ArrayAttr> hasWaitDevnum,
532 mlir::acc::DeviceType deviceType) {
535 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
536 if (hasWaitDevnum->getValue()[*pos])
547 std::optional<mlir::ArrayAttr> hasWaitDevnum,
548 mlir::acc::DeviceType deviceType) {
553 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
554 if (hasWaitDevnum && *hasWaitDevnum) {
555 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
556 if (boolAttr.getValue())
557 return range.drop_front(1);
563template <
typename Op>
565 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
567 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
572 op.hasAsyncOnly(dtype))
574 "asyncOnly attribute cannot appear with asyncOperand");
579 op.hasWaitOnly(dtype))
580 return op.
emitError(
"wait attribute cannot appear with waitOperands");
585template <
typename Op>
588 return op.
emitError(
"must have var operand");
591 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
592 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
593 return op.
emitError(
"var must be mappable or pointer-like");
596 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
597 op.getVarType() == op.getVar().getType())
598 return op.
emitError(
"varType must capture the element type of var");
603template <
typename Op>
605 if (op.getVar().getType() != op.getAccVar().getType())
606 return op.
emitError(
"input and output types must match");
611template <
typename Op>
613 if (op.getModifiers() != acc::DataClauseModifier::none)
614 return op.
emitError(
"no data clause modifiers are allowed");
618template <
typename Op>
621 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
623 "invalid data clause modifiers: " +
624 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
629template <
typename OpT,
typename RecipeOpT>
630static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
635 !std::is_same_v<OpT, acc::ReductionOp>)
638 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
640 return op->emitOpError() <<
"recipe expected for " << operandName;
645 return op->emitOpError()
646 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
647 << operandName <<
" declaration";
668 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
689 if (failed(parser.
parseType(accVarType)))
699 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
711 mlir::TypeAttr &varTypeAttr) {
712 if (failed(parser.
parseType(varPtrType)))
723 varTypeAttr = mlir::TypeAttr::get(varType);
728 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
729 varTypeAttr = mlir::TypeAttr::get(
730 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
732 varTypeAttr = mlir::TypeAttr::get(varPtrType);
739 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
747 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
748 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
750 if (typeToCheckAgainst != varType) {
758 mlir::SymbolRefAttr &recipeAttr) {
765 mlir::SymbolRefAttr recipeAttr) {
772LogicalResult acc::DataBoundsOp::verify() {
773 auto extent = getExtent();
774 auto upperbound = getUpperbound();
775 if (!extent && !upperbound)
776 return emitError(
"expected extent or upperbound.");
783LogicalResult acc::PrivateOp::verify() {
786 "data clause associated with private operation must match its intent");
800LogicalResult acc::FirstprivateOp::verify() {
802 return emitError(
"data clause associated with firstprivate operation must "
809 *
this,
"firstprivate")))
817LogicalResult acc::FirstprivateMapInitialOp::verify() {
819 return emitError(
"data clause associated with firstprivate operation must "
831LogicalResult acc::ReductionOp::verify() {
833 return emitError(
"data clause associated with reduction operation must "
840 *
this,
"reduction")))
848LogicalResult acc::DevicePtrOp::verify() {
850 return emitError(
"data clause associated with deviceptr operation must "
864LogicalResult acc::PresentOp::verify() {
867 "data clause associated with present operation must match its intent");
880LogicalResult acc::CopyinOp::verify() {
882 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
887 "data clause associated with copyin operation must match its intent"
888 " or specify original clause this operation was decomposed from");
894 acc::DataClauseModifier::always |
895 acc::DataClauseModifier::capture)))
900bool acc::CopyinOp::isCopyinReadonly() {
901 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
902 acc::bitEnumContainsAny(getModifiers(),
903 acc::DataClauseModifier::readonly);
909LogicalResult acc::CreateOp::verify() {
916 "data clause associated with create operation must match its intent"
917 " or specify original clause this operation was decomposed from");
925 acc::DataClauseModifier::always |
926 acc::DataClauseModifier::capture)))
931bool acc::CreateOp::isCreateZero() {
933 return getDataClause() == acc::DataClause::acc_create_zero ||
935 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
941LogicalResult acc::NoCreateOp::verify() {
943 return emitError(
"data clause associated with no_create operation must "
957LogicalResult acc::AttachOp::verify() {
960 "data clause associated with attach operation must match its intent");
974LogicalResult acc::DeclareDeviceResidentOp::verify() {
975 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
976 return emitError(
"data clause associated with device_resident operation "
977 "must match its intent");
991LogicalResult acc::DeclareLinkOp::verify() {
994 "data clause associated with link operation must match its intent");
1007LogicalResult acc::CopyoutOp::verify() {
1014 "data clause associated with copyout operation must match its intent"
1015 " or specify original clause this operation was decomposed from");
1017 return emitError(
"must have both host and device pointers");
1023 acc::DataClauseModifier::always |
1024 acc::DataClauseModifier::capture)))
1029bool acc::CopyoutOp::isCopyoutZero() {
1030 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1031 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1037LogicalResult acc::DeleteOp::verify() {
1046 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1049 "data clause associated with delete operation must match its intent"
1050 " or specify original clause this operation was decomposed from");
1052 return emitError(
"must have device pointer");
1056 acc::DataClauseModifier::readonly |
1057 acc::DataClauseModifier::always |
1058 acc::DataClauseModifier::capture)))
1066LogicalResult acc::DetachOp::verify() {
1071 "data clause associated with detach operation must match its intent"
1072 " or specify original clause this operation was decomposed from");
1074 return emitError(
"must have device pointer");
1083LogicalResult acc::UpdateHostOp::verify() {
1088 "data clause associated with host operation must match its intent"
1089 " or specify original clause this operation was decomposed from");
1091 return emitError(
"must have both host and device pointers");
1104LogicalResult acc::UpdateDeviceOp::verify() {
1108 "data clause associated with device operation must match its intent"
1109 " or specify original clause this operation was decomposed from");
1122LogicalResult acc::UseDeviceOp::verify() {
1126 "data clause associated with use_device operation must match its intent"
1127 " or specify original clause this operation was decomposed from");
1140LogicalResult acc::CacheOp::verify() {
1145 "data clause associated with cache operation must match its intent"
1146 " or specify original clause this operation was decomposed from");
1156bool acc::CacheOp::isCacheReadonly() {
1157 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1158 acc::bitEnumContainsAny(getModifiers(),
1159 acc::DataClauseModifier::readonly);
1162template <
typename StructureOp>
1164 unsigned nRegions = 1) {
1167 for (
unsigned i = 0; i < nRegions; ++i)
1170 for (
Region *region : regions)
1178 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1185template <
typename OpTy>
1187 using OpRewritePattern<OpTy>::OpRewritePattern;
1189 LogicalResult matchAndRewrite(OpTy op,
1190 PatternRewriter &rewriter)
const override {
1192 Value ifCond = op.getIfCond();
1196 IntegerAttr constAttr;
1199 if (constAttr.getInt())
1200 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1212 assert(region.
hasOneBlock() &&
"expected single-block region");
1224template <
typename OpTy>
1225struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1226 using OpRewritePattern<OpTy>::OpRewritePattern;
1228 LogicalResult matchAndRewrite(OpTy op,
1229 PatternRewriter &rewriter)
const override {
1231 Value ifCond = op.getIfCond();
1235 IntegerAttr constAttr;
1238 if (constAttr.getInt())
1239 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1249struct RemoveEmptyKernelEnvironment
1251 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1253 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1254 PatternRewriter &rewriter)
const override {
1255 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1266 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1267 for (
auto attr : deviceTypeAttr) {
1268 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1269 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1276 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1277 for (
auto attr : hasDevnumAttr) {
1278 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1279 if (boolAttr.getValue())
1286 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1287 if (segmentsAttr.size() > 1)
1293 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1320 for (
Value bound : bounds) {
1321 argTypes.push_back(bound.getType());
1322 argLocs.push_back(loc);
1329 Value privatizedValue;
1335 if (isa<MappableType>(varType)) {
1336 auto mappableTy = cast<MappableType>(varType);
1337 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1338 privatizedValue = mappableTy.generatePrivateInit(
1339 builder, loc, typedVar, varName, bounds, {}, needsFree);
1340 if (!privatizedValue)
1343 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1344 auto pointerLikeTy = cast<PointerLikeType>(varType);
1346 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1347 blockArgVar, needsFree);
1348 if (!privatizedValue)
1353 acc::YieldOp::create(builder, loc, privatizedValue);
1368 for (
Value bound : bounds) {
1369 copyArgTypes.push_back(bound.getType());
1370 copyArgLocs.push_back(loc);
1377 bool isMappable = isa<MappableType>(varType);
1378 bool isPointerLike = isa<PointerLikeType>(varType);
1381 if (isMappable && !isPointerLike)
1385 if (isPointerLike) {
1386 auto pointerLikeTy = cast<PointerLikeType>(varType);
1391 if (!pointerLikeTy.genCopy(
1398 acc::TerminatorOp::create(builder, loc);
1412 for (
Value bound : bounds) {
1413 destroyArgTypes.push_back(bound.getType());
1414 destroyArgLocs.push_back(loc);
1418 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1422 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1423 if (isa<MappableType>(varType)) {
1424 auto mappableTy = cast<MappableType>(varType);
1425 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1428 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1429 auto pointerLikeTy = cast<PointerLikeType>(varType);
1430 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1434 acc::TerminatorOp::create(builder, loc);
1445 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1447 if (optional && region.
empty())
1451 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1455 return op->
emitOpError() <<
"expects " << regionName
1458 << regionType <<
" type";
1461 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1462 if (yieldOp.getOperands().size() != 1 ||
1463 yieldOp.getOperands().getTypes()[0] != type)
1464 return op->
emitOpError() <<
"expects " << regionName
1466 "yield a value of the "
1467 << regionType <<
" type";
1473LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1475 "privatization",
"init",
getType(),
1479 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1485std::optional<PrivateRecipeOp>
1487 StringRef recipeName,
Type varType,
1490 bool isMappable = isa<MappableType>(varType);
1491 bool isPointerLike = isa<PointerLikeType>(varType);
1494 if (!isMappable && !isPointerLike)
1495 return std::nullopt;
1500 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1503 bool needsFree =
false;
1504 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1505 varName, bounds, needsFree))) {
1507 return std::nullopt;
1514 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1515 Value allocRes = yieldOp.getOperand(0);
1517 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1518 varType, allocRes, bounds))) {
1520 return std::nullopt;
1527std::optional<PrivateRecipeOp>
1529 StringRef recipeName,
1530 FirstprivateRecipeOp firstprivRecipe) {
1533 auto varType = firstprivRecipe.getType();
1534 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1538 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1541 if (!firstprivRecipe.getDestroyRegion().empty()) {
1543 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1553LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1555 "privatization",
"init",
getType(),
1559 if (getCopyRegion().empty())
1560 return emitOpError() <<
"expects non-empty copy region";
1565 return emitOpError() <<
"expects copy region with two arguments of the "
1566 "privatization type";
1568 if (getDestroyRegion().empty())
1572 "privatization",
"destroy",
1579std::optional<FirstprivateRecipeOp>
1581 StringRef recipeName,
Type varType,
1584 bool isMappable = isa<MappableType>(varType);
1585 bool isPointerLike = isa<PointerLikeType>(varType);
1588 if (!isMappable && !isPointerLike)
1589 return std::nullopt;
1594 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1597 bool needsFree =
false;
1598 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1599 varName, bounds, needsFree))) {
1601 return std::nullopt;
1605 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1608 return std::nullopt;
1615 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1616 Value allocRes = yieldOp.getOperand(0);
1618 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1619 varType, allocRes, bounds))) {
1621 return std::nullopt;
1632LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1638 if (getCombinerRegion().empty())
1639 return emitOpError() <<
"expects non-empty combiner region";
1641 Block &reductionBlock = getCombinerRegion().
front();
1645 return emitOpError() <<
"expects combiner region with the first two "
1646 <<
"arguments of the reduction type";
1648 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1649 if (yieldOp.getOperands().size() != 1 ||
1650 yieldOp.getOperands().getTypes()[0] !=
getType())
1651 return emitOpError() <<
"expects combiner region to yield a value "
1652 "of the reduction type";
1663template <
typename Op>
1667 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1668 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1669 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1670 operand.getDefiningOp()))
1672 "expect data entry/exit operation or acc.getdeviceptr "
1677template <
typename OpT,
typename RecipeOpT>
1680 llvm::StringRef operandName) {
1683 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1685 <<
"expected " << operandName <<
" as defining op";
1686 if (!set.insert(operand).second)
1688 << operandName <<
" operand appears more than once";
1693unsigned ParallelOp::getNumDataOperands() {
1694 return getReductionOperands().size() + getPrivateOperands().size() +
1695 getFirstprivateOperands().size() + getDataClauseOperands().size();
1698Value ParallelOp::getDataOperand(
unsigned i) {
1700 numOptional += getNumGangs().size();
1701 numOptional += getNumWorkers().size();
1702 numOptional += getVectorLength().size();
1703 numOptional += getIfCond() ? 1 : 0;
1704 numOptional += getSelfCond() ? 1 : 0;
1705 return getOperand(getWaitOperands().size() + numOptional + i);
1708template <
typename Op>
1711 llvm::StringRef keyword) {
1712 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1713 return op.
emitOpError() << keyword <<
" operands count must match "
1714 << keyword <<
" device_type count";
1718template <
typename Op>
1721 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1722 std::size_t numOperandsInSegments = 0;
1723 std::size_t nbOfSegments = 0;
1726 for (
auto segCount : segments.
asArrayRef()) {
1727 if (maxInSegment != 0 && segCount > maxInSegment)
1728 return op.
emitOpError() << keyword <<
" expects a maximum of "
1729 << maxInSegment <<
" values per segment";
1730 numOperandsInSegments += segCount;
1735 if ((numOperandsInSegments != operands.size()) ||
1736 (!deviceTypes && !operands.empty()))
1738 << keyword <<
" operand count does not match count in segments";
1739 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1741 << keyword <<
" segment count does not match device_type count";
1745LogicalResult acc::ParallelOp::verify() {
1747 mlir::acc::PrivateRecipeOp>(
1748 *
this, getPrivateOperands(),
"private")))
1751 mlir::acc::FirstprivateRecipeOp>(
1752 *
this, getFirstprivateOperands(),
"firstprivate")))
1755 mlir::acc::ReductionRecipeOp>(
1756 *
this, getReductionOperands(),
"reduction")))
1760 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1761 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1765 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1766 getWaitOperandsDeviceTypeAttr(),
"wait")))
1770 getNumWorkersDeviceTypeAttr(),
1775 getVectorLengthDeviceTypeAttr(),
1780 getAsyncOperandsDeviceTypeAttr(),
1793 mlir::acc::DeviceType deviceType) {
1796 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1801bool acc::ParallelOp::hasAsyncOnly() {
1802 return hasAsyncOnly(mlir::acc::DeviceType::None);
1805bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1810 return getAsyncValue(mlir::acc::DeviceType::None);
1813mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1818mlir::Value acc::ParallelOp::getNumWorkersValue() {
1819 return getNumWorkersValue(mlir::acc::DeviceType::None);
1823acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1828mlir::Value acc::ParallelOp::getVectorLengthValue() {
1829 return getVectorLengthValue(mlir::acc::DeviceType::None);
1833acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1835 getVectorLength(), deviceType);
1839 return getNumGangsValues(mlir::acc::DeviceType::None);
1843ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1845 getNumGangsSegments(), deviceType);
1848bool acc::ParallelOp::hasWaitOnly() {
1849 return hasWaitOnly(mlir::acc::DeviceType::None);
1852bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1857 return getWaitValues(mlir::acc::DeviceType::None);
1861ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1863 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1864 getHasWaitDevnum(), deviceType);
1868 return getWaitDevnum(mlir::acc::DeviceType::None);
1871mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1873 getWaitOperandsSegments(), getHasWaitDevnum(),
1888 odsBuilder, odsState, asyncOperands,
nullptr,
1889 nullptr, waitOperands,
nullptr,
1891 nullptr, numGangs,
nullptr,
1892 nullptr, numWorkers,
1893 nullptr, vectorLength,
1894 nullptr, ifCond, selfCond,
1895 nullptr, reductionOperands, gangPrivateOperands,
1896 gangFirstPrivateOperands, dataClauseOperands,
1900void acc::ParallelOp::addNumWorkersOperand(
1903 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1904 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1905 getNumWorkersMutable()));
1907void acc::ParallelOp::addVectorLengthOperand(
1910 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912 getVectorLengthMutable()));
1915void acc::ParallelOp::addAsyncOnly(
1917 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1918 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1921void acc::ParallelOp::addAsyncOperand(
1924 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1925 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1926 getAsyncOperandsMutable()));
1929void acc::ParallelOp::addNumGangsOperands(
1933 if (getNumGangsSegments())
1934 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1936 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1937 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1938 getNumGangsMutable(), segments));
1940 setNumGangsSegments(segments);
1942void acc::ParallelOp::addWaitOnly(
1944 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1945 effectiveDeviceTypes));
1947void acc::ParallelOp::addWaitOperands(
1952 if (getWaitOperandsSegments())
1953 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1955 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1956 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1957 getWaitOperandsMutable(), segments));
1958 setWaitOperandsSegments(segments);
1961 if (getHasWaitDevnumAttr())
1962 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1965 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1967 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1970void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1971 mlir::acc::PrivateOp op,
1972 mlir::acc::PrivateRecipeOp recipe) {
1973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1974 getPrivateOperandsMutable().append(op.getResult());
1977void acc::ParallelOp::addFirstPrivatization(
1978 MLIRContext *context, mlir::acc::FirstprivateOp op,
1979 mlir::acc::FirstprivateRecipeOp recipe) {
1980 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1981 getFirstprivateOperandsMutable().append(op.getResult());
1984void acc::ParallelOp::addReduction(
MLIRContext *context,
1985 mlir::acc::ReductionOp op,
1986 mlir::acc::ReductionRecipeOp recipe) {
1987 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1988 getReductionOperandsMutable().append(op.getResult());
2003 int32_t crtOperandsSize = operands.size();
2006 if (parser.parseOperand(operands.emplace_back()) ||
2007 parser.parseColonType(types.emplace_back()))
2012 seg.push_back(operands.size() - crtOperandsSize);
2022 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2023 parser.
getContext(), mlir::acc::DeviceType::None));
2029 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2036 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2037 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2038 p <<
" [" << attr <<
"]";
2043 std::optional<mlir::ArrayAttr> deviceTypes,
2044 std::optional<mlir::DenseI32ArrayAttr> segments) {
2046 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2048 llvm::interleaveComma(
2049 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2050 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2070 int32_t crtOperandsSize = operands.size();
2074 if (parser.parseOperand(operands.emplace_back()) ||
2075 parser.parseColonType(types.emplace_back()))
2081 seg.push_back(operands.size() - crtOperandsSize);
2091 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2092 parser.
getContext(), mlir::acc::DeviceType::None));
2098 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2107 std::optional<mlir::DenseI32ArrayAttr> segments) {
2109 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2111 llvm::interleaveComma(
2112 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2113 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2126 mlir::ArrayAttr &keywordOnly) {
2130 bool needCommaBeforeOperands =
false;
2134 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2135 parser.
getContext(), mlir::acc::DeviceType::None));
2136 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2143 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2150 needCommaBeforeOperands =
true;
2153 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2160 int32_t crtOperandsSize = operands.size();
2172 if (parser.parseOperand(operands.emplace_back()) ||
2173 parser.parseColonType(types.emplace_back()))
2179 seg.push_back(operands.size() - crtOperandsSize);
2189 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2190 parser.
getContext(), mlir::acc::DeviceType::None));
2197 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2198 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2200 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2208 if (attrs->size() != 1)
2210 if (
auto deviceTypeAttr =
2211 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2212 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2218 std::optional<mlir::ArrayAttr> deviceTypes,
2219 std::optional<mlir::DenseI32ArrayAttr> segments,
2220 std::optional<mlir::ArrayAttr> hasDevNum,
2221 std::optional<mlir::ArrayAttr> keywordOnly) {
2234 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2236 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2237 if (boolAttr && boolAttr.getValue())
2239 llvm::interleaveComma(
2240 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2241 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2258 if (parser.parseOperand(operands.emplace_back()) ||
2259 parser.parseColonType(types.emplace_back()))
2261 if (succeeded(parser.parseOptionalLSquare())) {
2262 if (parser.parseAttribute(attributes.emplace_back()) ||
2263 parser.parseRSquare())
2266 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2267 parser.getContext(), mlir::acc::DeviceType::None));
2274 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2281 std::optional<mlir::ArrayAttr> deviceTypes) {
2284 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2285 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2294 mlir::ArrayAttr &keywordOnlyDeviceType) {
2297 bool needCommaBeforeOperands =
false;
2301 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2302 parser.
getContext(), mlir::acc::DeviceType::None));
2303 keywordOnlyDeviceType =
2304 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2312 if (parser.parseAttribute(
2313 keywordOnlyDeviceTypeAttributes.emplace_back()))
2320 needCommaBeforeOperands =
true;
2323 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2328 if (parser.parseOperand(operands.emplace_back()) ||
2329 parser.parseColonType(types.emplace_back()))
2331 if (succeeded(parser.parseOptionalLSquare())) {
2332 if (parser.parseAttribute(attributes.emplace_back()) ||
2333 parser.parseRSquare())
2336 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2337 parser.getContext(), mlir::acc::DeviceType::None));
2343 if (
failed(parser.parseRParen()))
2348 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2355 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2357 if (operands.begin() == operands.end() &&
2373 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2374 mlir::Type &operandType, mlir::UnitAttr &attr) {
2377 attr = mlir::UnitAttr::get(parser.
getContext());
2387 if (failed(parser.
parseType(operandType)))
2397 std::optional<mlir::Value> operand,
2399 mlir::UnitAttr attr) {
2416 attr = mlir::UnitAttr::get(parser.
getContext());
2421 if (parser.parseOperand(operands.emplace_back()))
2429 if (parser.parseType(types.emplace_back()))
2444 mlir::UnitAttr attr) {
2449 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2451 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2457 mlir::acc::CombinedConstructsTypeAttr &attr) {
2459 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2460 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2462 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2463 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2465 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2466 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2469 "expected compute construct name");
2477 mlir::acc::CombinedConstructsTypeAttr attr) {
2479 switch (attr.getValue()) {
2480 case mlir::acc::CombinedConstructsType::KernelsLoop:
2483 case mlir::acc::CombinedConstructsType::ParallelLoop:
2486 case mlir::acc::CombinedConstructsType::SerialLoop:
// Number of data operands on acc.serial. This is the sum of the reduction,
// private, firstprivate and data-clause operand counts.
2497unsigned SerialOp::getNumDataOperands() {
2498 return getReductionOperands().size() + getPrivateOperands().size() +
2499 getFirstprivateOperands().size() + getDataClauseOperands().size();
2502Value SerialOp::getDataOperand(
unsigned i) {
2504 numOptional += getIfCond() ? 1 : 0;
2505 numOptional += getSelfCond() ? 1 : 0;
2506 return getOperand(getWaitOperands().size() + numOptional + i);
// Convenience overload: queries the async-only flag for
// mlir::acc::DeviceType::None via the device-type-aware overload.
2509bool acc::SerialOp::hasAsyncOnly() {
2510 return hasAsyncOnly(mlir::acc::DeviceType::None);
2513bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2518 return getAsyncValue(mlir::acc::DeviceType::None);
2521mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
// Convenience overload: queries the wait-only flag for
// mlir::acc::DeviceType::None via the device-type-aware overload.
2526bool acc::SerialOp::hasWaitOnly() {
2527 return hasWaitOnly(mlir::acc::DeviceType::None);
2530bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2535 return getWaitValues(mlir::acc::DeviceType::None);
2539SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2541 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2542 getHasWaitDevnum(), deviceType);
2546 return getWaitDevnum(mlir::acc::DeviceType::None);
2549mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2551 getWaitOperandsSegments(), getHasWaitDevnum(),
2555LogicalResult acc::SerialOp::verify() {
2557 mlir::acc::PrivateRecipeOp>(
2558 *
this, getPrivateOperands(),
"private")))
2561 mlir::acc::FirstprivateRecipeOp>(
2562 *
this, getFirstprivateOperands(),
"firstprivate")))
2565 mlir::acc::ReductionRecipeOp>(
2566 *
this, getReductionOperands(),
"reduction")))
2570 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2571 getWaitOperandsDeviceTypeAttr(),
"wait")))
2575 getAsyncOperandsDeviceTypeAttr(),
2585void acc::SerialOp::addAsyncOnly(
2587 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2588 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2591void acc::SerialOp::addAsyncOperand(
2594 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2595 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2596 getAsyncOperandsMutable()));
2599void acc::SerialOp::addWaitOnly(
2601 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2602 effectiveDeviceTypes));
2604void acc::SerialOp::addWaitOperands(
2609 if (getWaitOperandsSegments())
2610 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2612 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2613 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2614 getWaitOperandsMutable(), segments));
2615 setWaitOperandsSegments(segments);
2618 if (getHasWaitDevnumAttr())
2619 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2622 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2624 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
// Registers a private variable on this acc.serial op: the data-clause op's
// recipe attribute is set to a symbol reference naming `recipe`, and the op's
// result is appended to the private operand list.
2627void acc::SerialOp::addPrivatization(
MLIRContext *context,
2628 mlir::acc::PrivateOp op,
2629 mlir::acc::PrivateRecipeOp recipe) {
2630 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2631 getPrivateOperandsMutable().append(op.getResult());
// Firstprivate counterpart: link the recipe by symbol name, then append the
// result to the firstprivate operand list.
2634void acc::SerialOp::addFirstPrivatization(
2635 MLIRContext *context, mlir::acc::FirstprivateOp op,
2636 mlir::acc::FirstprivateRecipeOp recipe) {
2637 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2638 getFirstprivateOperandsMutable().append(op.getResult());
// Reduction counterpart: link the reduction recipe by symbol name, then
// append the result to the reduction operand list.
2641void acc::SerialOp::addReduction(
MLIRContext *context,
2642 mlir::acc::ReductionOp op,
2643 mlir::acc::ReductionRecipeOp recipe) {
2644 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2645 getReductionOperandsMutable().append(op.getResult());
// Number of data operands on acc.kernels; only the data-clause operands are
// counted here.
// NOTE(review): SerialOp::getNumDataOperands also counts reduction, private
// and firstprivate operands, and KernelsOp does carry those lists (see its
// addPrivatization/addFirstPrivatization/addReduction). Confirm that
// excluding them here, and the matching offset in getDataOperand, is
// intentional.
2652unsigned KernelsOp::getNumDataOperands() {
2653 return getDataClauseOperands().size();
2656Value KernelsOp::getDataOperand(
unsigned i) {
2658 numOptional += getWaitOperands().size();
2659 numOptional += getNumGangs().size();
2660 numOptional += getNumWorkers().size();
2661 numOptional += getVectorLength().size();
2662 numOptional += getIfCond() ? 1 : 0;
2663 numOptional += getSelfCond() ? 1 : 0;
2664 return getOperand(numOptional + i);
// Convenience overload: queries the async-only flag for
// mlir::acc::DeviceType::None via the device-type-aware overload.
2667bool acc::KernelsOp::hasAsyncOnly() {
2668 return hasAsyncOnly(mlir::acc::DeviceType::None);
2671bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2676 return getAsyncValue(mlir::acc::DeviceType::None);
2679mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2685 return getNumWorkersValue(mlir::acc::DeviceType::None);
2689acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
// Convenience overload: returns the vector_length value recorded for
// mlir::acc::DeviceType::None via the device-type-aware overload.
2694mlir::Value acc::KernelsOp::getVectorLengthValue() {
2695 return getVectorLengthValue(mlir::acc::DeviceType::None);
2699acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2701 getVectorLength(), deviceType);
2705 return getNumGangsValues(mlir::acc::DeviceType::None);
2709KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2711 getNumGangsSegments(), deviceType);
// Convenience overload: queries the wait-only flag for
// mlir::acc::DeviceType::None via the device-type-aware overload.
2714bool acc::KernelsOp::hasWaitOnly() {
2715 return hasWaitOnly(mlir::acc::DeviceType::None);
2718bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2723 return getWaitValues(mlir::acc::DeviceType::None);
2727KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2729 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2730 getHasWaitDevnum(), deviceType);
2734 return getWaitDevnum(mlir::acc::DeviceType::None);
2737mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2739 getWaitOperandsSegments(), getHasWaitDevnum(),
2743LogicalResult acc::KernelsOp::verify() {
2745 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2746 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2750 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2751 getWaitOperandsDeviceTypeAttr(),
"wait")))
2755 getNumWorkersDeviceTypeAttr(),
2760 getVectorLengthDeviceTypeAttr(),
2765 getAsyncOperandsDeviceTypeAttr(),
// Registers a private variable on this acc.kernels op: the data-clause op's
// recipe attribute is set to a symbol reference naming `recipe`, and the op's
// result is appended to the private operand list.
2775void acc::KernelsOp::addPrivatization(
MLIRContext *context,
2776 mlir::acc::PrivateOp op,
2777 mlir::acc::PrivateRecipeOp recipe) {
2778 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2779 getPrivateOperandsMutable().append(op.getResult());
// Firstprivate counterpart: link the recipe by symbol name, then append the
// result to the firstprivate operand list.
2782void acc::KernelsOp::addFirstPrivatization(
2783 MLIRContext *context, mlir::acc::FirstprivateOp op,
2784 mlir::acc::FirstprivateRecipeOp recipe) {
2785 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2786 getFirstprivateOperandsMutable().append(op.getResult());
// Reduction counterpart: link the reduction recipe by symbol name, then
// append the result to the reduction operand list.
2789void acc::KernelsOp::addReduction(
MLIRContext *context,
2790 mlir::acc::ReductionOp op,
2791 mlir::acc::ReductionRecipeOp recipe) {
2792 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2793 getReductionOperandsMutable().append(op.getResult());
2796void acc::KernelsOp::addNumWorkersOperand(
2799 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2800 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2801 getNumWorkersMutable()));
2804void acc::KernelsOp::addVectorLengthOperand(
2807 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2808 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2809 getVectorLengthMutable()));
2811void acc::KernelsOp::addAsyncOnly(
2813 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2814 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2817void acc::KernelsOp::addAsyncOperand(
2820 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2821 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2822 getAsyncOperandsMutable()));
2825void acc::KernelsOp::addNumGangsOperands(
2829 if (getNumGangsSegmentsAttr())
2830 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2832 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2833 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2834 getNumGangsMutable(), segments));
2836 setNumGangsSegments(segments);
2839void acc::KernelsOp::addWaitOnly(
2841 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2842 effectiveDeviceTypes));
2844void acc::KernelsOp::addWaitOperands(
2849 if (getWaitOperandsSegments())
2850 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2852 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2853 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2854 getWaitOperandsMutable(), segments));
2855 setWaitOperandsSegments(segments);
2858 if (getHasWaitDevnumAttr())
2859 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2862 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2864 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2871LogicalResult acc::HostDataOp::verify() {
2872 if (getDataClauseOperands().empty())
2873 return emitError(
"at least one operand must appear on the host_data "
2876 for (
mlir::Value operand : getDataClauseOperands())
2877 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2878 return emitError(
"expect data entry operation as defining op");
2884 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2891void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2893 results.
add<RemoveEmptyKernelEnvironment>(context);
2905 bool &needCommaBetweenValues,
bool &newValue) {
2912 attributes.push_back(gangArgType);
2913 needCommaBetweenValues =
true;
2924 mlir::ArrayAttr &gangOnlyDeviceType) {
2929 bool needCommaBetweenValues =
false;
2930 bool needCommaBeforeOperands =
false;
2934 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2935 parser.
getContext(), mlir::acc::DeviceType::None));
2936 gangOnlyDeviceType =
2937 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
2945 if (parser.parseAttribute(
2946 gangOnlyDeviceTypeAttributes.emplace_back()))
2953 needCommaBeforeOperands =
true;
2956 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2957 mlir::acc::GangArgType::Num);
2958 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2959 mlir::acc::GangArgType::Dim);
2960 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2961 parser.
getContext(), mlir::acc::GangArgType::Static);
2964 if (needCommaBeforeOperands) {
2965 needCommaBeforeOperands =
false;
2972 int32_t crtOperandsSize = gangOperands.size();
2974 bool newValue =
false;
2975 bool needValue =
false;
2976 if (needCommaBetweenValues) {
2984 gangOperands, gangOperandsType,
2985 gangArgTypeAttributes, argNum,
2986 needCommaBetweenValues, newValue)))
2989 gangOperands, gangOperandsType,
2990 gangArgTypeAttributes, argDim,
2991 needCommaBetweenValues, newValue)))
2993 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2994 gangOperands, gangOperandsType,
2995 gangArgTypeAttributes, argStatic,
2996 needCommaBetweenValues, newValue)))
2999 if (!newValue && needValue) {
3001 "new value expected after comma");
3009 if (gangOperands.empty())
3012 "expect at least one of num, dim or static values");
3018 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3022 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3023 parser.
getContext(), mlir::acc::DeviceType::None));
3026 seg.push_back(gangOperands.size() - crtOperandsSize);
3034 gangArgTypeAttributes.end());
3035 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3036 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3039 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3040 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3048 std::optional<mlir::ArrayAttr> gangArgTypes,
3049 std::optional<mlir::ArrayAttr> deviceTypes,
3050 std::optional<mlir::DenseI32ArrayAttr> segments,
3051 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3053 if (operands.begin() == operands.end() &&
3068 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3070 llvm::interleaveComma(
3071 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3072 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3073 (*gangArgTypes)[opIdx]);
3074 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3075 p << LoopOp::getGangNumKeyword();
3076 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3077 p << LoopOp::getGangDimKeyword();
3078 else if (gangArgTypeAttr.getValue() ==
3079 mlir::acc::GangArgType::Static)
3080 p << LoopOp::getGangStaticKeyword();
3081 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3092 std::optional<mlir::ArrayAttr> segments,
3093 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3096 for (
auto attr : *segments) {
3097 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3098 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3106static std::optional<mlir::acc::DeviceType>
3108 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3110 return std::nullopt;
3111 for (
auto attr : deviceTypes) {
3112 auto deviceTypeAttr =
3113 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3114 if (!deviceTypeAttr)
3115 return mlir::acc::DeviceType::None;
3116 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3117 return deviceTypeAttr.getValue();
3119 return std::nullopt;
3122LogicalResult acc::LoopOp::verify() {
3123 if (getUpperbound().size() != getStep().size())
3124 return emitError() <<
"number of upperbounds expected to be the same as "
3127 if (getUpperbound().size() != getLowerbound().size())
3128 return emitError() <<
"number of upperbounds expected to be the same as "
3129 "number of lowerbounds";
3131 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3132 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3133 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3134 <<
" as upperbound size";
3137 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3138 return emitOpError() <<
"collapse device_type attr must be define when"
3139 <<
" collapse attr is present";
3141 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3142 getCollapseAttr().getValue().size() !=
3143 getCollapseDeviceTypeAttr().getValue().size())
3144 return emitOpError() <<
"collapse attribute count must match collapse"
3145 <<
" device_type count";
3146 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3148 << acc::stringifyDeviceType(*duplicateDeviceType)
3149 <<
"` found in collapseDeviceType attribute";
3152 if (!getGangOperands().empty()) {
3153 if (!getGangOperandsArgType())
3154 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3155 <<
" when gang operands are present";
3157 if (getGangOperands().size() !=
3158 getGangOperandsArgTypeAttr().getValue().size())
3159 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3160 <<
" gangOperands count";
3162 if (getGangAttr()) {
3165 << acc::stringifyDeviceType(*duplicateDeviceType)
3166 <<
"` found in gang attribute";
3170 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3171 getGangOperandsDeviceTypeAttr(),
"gang")))
3177 << acc::stringifyDeviceType(*duplicateDeviceType)
3178 <<
"` found in worker attribute";
3179 if (
auto duplicateDeviceType =
3182 << acc::stringifyDeviceType(*duplicateDeviceType)
3183 <<
"` found in workerNumOperandsDeviceType attribute";
3185 getWorkerNumOperandsDeviceTypeAttr(),
3192 << acc::stringifyDeviceType(*duplicateDeviceType)
3193 <<
"` found in vector attribute";
3194 if (
auto duplicateDeviceType =
3197 << acc::stringifyDeviceType(*duplicateDeviceType)
3198 <<
"` found in vectorOperandsDeviceType attribute";
3200 getVectorOperandsDeviceTypeAttr(),
3205 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3206 getTileOperandsDeviceTypeAttr(),
"tile")))
3210 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3214 return emitError() <<
"only one of auto, independent, seq can be present "
3220 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3221 return attr.getValue() == mlir::acc::DeviceType::None;
3223 bool hasDefaultSeq =
3225 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3228 bool hasDefaultIndependent =
3229 getIndependentAttr()
3231 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3234 bool hasDefaultAuto =
3236 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3239 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3241 <<
"at least one of auto, independent, seq must be present";
3246 for (
auto attr : getSeqAttr()) {
3247 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3248 if (hasVector(deviceTypeAttr.getValue()) ||
3249 getVectorValue(deviceTypeAttr.getValue()) ||
3250 hasWorker(deviceTypeAttr.getValue()) ||
3251 getWorkerValue(deviceTypeAttr.getValue()) ||
3252 hasGang(deviceTypeAttr.getValue()) ||
3253 getGangValue(mlir::acc::GangArgType::Num,
3254 deviceTypeAttr.getValue()) ||
3255 getGangValue(mlir::acc::GangArgType::Dim,
3256 deviceTypeAttr.getValue()) ||
3257 getGangValue(mlir::acc::GangArgType::Static,
3258 deviceTypeAttr.getValue()))
3259 return emitError() <<
"gang, worker or vector cannot appear with seq";
3264 mlir::acc::PrivateRecipeOp>(
3265 *
this, getPrivateOperands(),
"private")))
3269 mlir::acc::FirstprivateRecipeOp>(
3270 *
this, getFirstprivateOperands(),
"firstprivate")))
3274 mlir::acc::ReductionRecipeOp>(
3275 *
this, getReductionOperands(),
"reduction")))
3278 if (getCombined().has_value() &&
3279 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3280 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3281 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3282 return emitError(
"unexpected combined constructs attribute");
3286 if (getRegion().empty())
3287 return emitError(
"expected non-empty body.");
3289 if (getUnstructured()) {
3290 if (!isContainerLike())
3292 "unstructured acc.loop must not have induction variables");
3293 }
else if (isContainerLike()) {
3297 uint64_t collapseCount = getCollapseValue().value_or(1);
3298 if (getCollapseAttr()) {
3299 for (
auto collapseEntry : getCollapseAttr()) {
3300 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3301 if (intAttr.getValue().getZExtValue() > collapseCount)
3302 collapseCount = intAttr.getValue().getZExtValue();
3310 bool foundSibling =
false;
3312 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3314 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3316 foundSibling =
true;
3321 expectedParent = op;
3324 if (collapseCount == 0)
3330 return emitError(
"found sibling loops inside container-like acc.loop");
3331 if (collapseCount != 0)
3332 return emitError(
"failed to find enough loop-like operations inside "
3333 "container-like acc.loop");
// Number of data operands on acc.loop. This is the sum of the reduction,
// private and firstprivate operand counts.
3339unsigned LoopOp::getNumDataOperands() {
3340 return getReductionOperands().size() + getPrivateOperands().size() +
3341 getFirstprivateOperands().size();
// Returns the i-th data operand. The function first skips the
// lowerbound/upperbound/step, gang, vector, worker-num, tile and cache
// operands, then indexes into the remaining operands.
// NOTE(review): this assumes those groups precede the data operands in the
// op's flat operand list. Confirm against the ODS operand order if operands
// are added or reordered.
3344Value LoopOp::getDataOperand(
unsigned i) {
3345 unsigned numOptional =
3346 getLowerbound().size() + getUpperbound().size() + getStep().size();
3347 numOptional += getGangOperands().size();
3348 numOptional += getVectorOperands().size();
3349 numOptional += getWorkerNumOperands().size();
3350 numOptional += getTileOperands().size();
3351 numOptional += getCacheOperands().size();
3352 return getOperand(numOptional + i);
// Convenience overload: checks the `auto` flag for mlir::acc::DeviceType::None.
3355bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3357bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
// Convenience overload: checks the `independent` flag for
// mlir::acc::DeviceType::None.
3361bool LoopOp::hasIndependent() {
3362 return hasIndependent(mlir::acc::DeviceType::None);
3365bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
// Convenience overload: checks the `seq` flag for mlir::acc::DeviceType::None.
3369bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3371bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3376 return getVectorValue(mlir::acc::DeviceType::None);
3379mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3381 getVectorOperands(), deviceType);
// Convenience overload: checks the operand-less `vector` flag for
// mlir::acc::DeviceType::None.
3384bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3386bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3391 return getWorkerValue(mlir::acc::DeviceType::None);
3394mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3396 getWorkerNumOperands(), deviceType);
// Convenience overload: checks the operand-less `worker` flag for
// mlir::acc::DeviceType::None.
3399bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3401bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3406 return getTileValues(mlir::acc::DeviceType::None);
3410LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3412 getTileOperandsSegments(), deviceType);
// Convenience overload: returns the collapse count recorded for
// mlir::acc::DeviceType::None, or std::nullopt if the device-type-aware
// overload finds none.
3415std::optional<int64_t> LoopOp::getCollapseValue() {
3416 return getCollapseValue(mlir::acc::DeviceType::None);
3419std::optional<int64_t>
3420LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3421 if (!getCollapseAttr())
3422 return std::nullopt;
3423 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3425 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3426 return intAttr.getValue().getZExtValue();
3428 return std::nullopt;
// Convenience overload: returns the gang operand of kind `gangArgType`
// (num/dim/static) recorded for mlir::acc::DeviceType::None.
3431mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3432 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3435mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3436 mlir::acc::DeviceType deviceType) {
3437 if (getGangOperands().empty())
3439 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3440 int32_t nbOperandsBefore = 0;
3441 for (
unsigned i = 0; i < *pos; ++i)
3442 nbOperandsBefore += (*getGangOperandsSegments())[i];
3445 .drop_front(nbOperandsBefore)
3446 .take_front((*getGangOperandsSegments())[*pos]);
3448 int32_t argTypeIdx = nbOperandsBefore;
3449 for (
auto value : values) {
3450 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3451 (*getGangOperandsArgType())[argTypeIdx]);
3452 if (gangArgTypeAttr.getValue() == gangArgType)
// Convenience overload: checks the operand-less `gang` flag for
// mlir::acc::DeviceType::None.
3460bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3462bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3467 return {&getRegion()};
3511 if (!regionArgs.empty()) {
3512 p << acc::LoopOp::getControlKeyword() <<
"(";
3513 llvm::interleaveComma(regionArgs, p,
3515 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3516 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3517 <<
" : " << stepType <<
") ";
3524 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3525 effectiveDeviceTypes));
3528void acc::LoopOp::addIndependent(
3530 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3531 context, getIndependentAttr(), effectiveDeviceTypes));
3536 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3537 effectiveDeviceTypes));
3540void acc::LoopOp::setCollapseForDeviceTypes(
3542 llvm::APInt value) {
3546 assert((getCollapseAttr() ==
nullptr) ==
3547 (getCollapseDeviceTypeAttr() ==
nullptr));
3548 assert(value.getBitWidth() == 64);
3550 if (getCollapseAttr()) {
3551 for (
const auto &existing :
3552 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3553 newValues.push_back(std::get<0>(existing));
3554 newDeviceTypes.push_back(std::get<1>(existing));
3558 if (effectiveDeviceTypes.empty()) {
3561 newValues.push_back(
3562 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3563 newDeviceTypes.push_back(
3564 acc::DeviceTypeAttr::get(context, DeviceType::None));
3566 for (DeviceType dt : effectiveDeviceTypes) {
3567 newValues.push_back(
3568 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3569 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3573 setCollapseAttr(ArrayAttr::get(context, newValues));
3574 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3577void acc::LoopOp::setTileForDeviceTypes(
3581 if (getTileOperandsSegments())
3582 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3584 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3585 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3586 getTileOperandsMutable(), segments));
3588 setTileOperandsSegments(segments);
3591void acc::LoopOp::addVectorOperand(
3594 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3595 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3596 newValue, getVectorOperandsMutable()));
3599void acc::LoopOp::addEmptyVector(
3601 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3602 effectiveDeviceTypes));
3605void acc::LoopOp::addWorkerNumOperand(
3608 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3609 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3610 newValue, getWorkerNumOperandsMutable()));
3613void acc::LoopOp::addEmptyWorker(
3615 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3616 effectiveDeviceTypes));
3619void acc::LoopOp::addEmptyGang(
3621 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3622 effectiveDeviceTypes));
3625bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3626 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3627 return attr.getValue() == dt;
3629 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3630 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3633 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3635 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3637 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
// Returns true if any gang, worker or vector parallelism is specified through
// the no-argument (DeviceType::None) accessors. This covers both the
// operand-less flags and the operand forms: the vector/worker values and the
// gang num/dim/static values.
3643bool acc::LoopOp::hasDefaultGangWorkerVector() {
3644 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3645 hasGang() || getGangValue(GangArgType::Num) ||
3646 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3650acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3651 if (hasSeq(deviceType))
3652 return LoopParMode::loop_seq;
3653 if (hasAuto(deviceType))
3654 return LoopParMode::loop_auto;
3655 if (hasIndependent(deviceType))
3656 return LoopParMode::loop_independent;
3658 return LoopParMode::loop_seq;
3660 return LoopParMode::loop_auto;
3661 assert(hasIndependent() &&
3662 "loop must have default auto, seq, or independent");
3663 return LoopParMode::loop_independent;
3666void acc::LoopOp::addGangOperands(
3671 getGangOperandsSegments())
3672 llvm::copy(*existingSegments, std::back_inserter(segments));
3674 unsigned beforeCount = segments.size();
3676 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3677 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3678 getGangOperandsMutable(), segments));
3680 setGangOperandsSegments(segments);
3687 unsigned numAdded = segments.size() - beforeCount;
3691 if (getGangOperandsArgTypeAttr())
3692 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3694 for (
auto i : llvm::index_range(0u, numAdded)) {
3695 llvm::transform(argTypes, std::back_inserter(gangTypes),
3696 [=](mlir::acc::GangArgType gangTy) {
3697 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3702 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3706void acc::LoopOp::addPrivatization(
MLIRContext *context,
3707 mlir::acc::PrivateOp op,
3708 mlir::acc::PrivateRecipeOp recipe) {
3709 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3710 getPrivateOperandsMutable().append(op.getResult());
3713void acc::LoopOp::addFirstPrivatization(
3714 MLIRContext *context, mlir::acc::FirstprivateOp op,
3715 mlir::acc::FirstprivateRecipeOp recipe) {
3716 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3717 getFirstprivateOperandsMutable().append(op.getResult());
3720void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3721 mlir::acc::ReductionRecipeOp recipe) {
3722 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3723 getReductionOperandsMutable().append(op.getResult());
3730LogicalResult acc::DataOp::verify() {
3735 return emitError(
"at least one operand or the default attribute "
3736 "must appear on the data operation");
3738 for (
mlir::Value operand : getDataClauseOperands())
3739 if (isa<BlockArgument>(operand) ||
3740 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3741 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3742 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3743 operand.getDefiningOp()))
3744 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3753unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3755Value DataOp::getDataOperand(
unsigned i) {
3756 unsigned numOptional = getIfCond() ? 1 : 0;
3758 numOptional += getWaitOperands().size();
3759 return getOperand(numOptional + i);
3762bool acc::DataOp::hasAsyncOnly() {
3763 return hasAsyncOnly(mlir::acc::DeviceType::None);
3766bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3771 return getAsyncValue(mlir::acc::DeviceType::None);
3774mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3779bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3781bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3786 return getWaitValues(mlir::acc::DeviceType::None);
3790DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3792 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3793 getHasWaitDevnum(), deviceType);
3797 return getWaitDevnum(mlir::acc::DeviceType::None);
3800mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3802 getWaitOperandsSegments(), getHasWaitDevnum(),
3806void acc::DataOp::addAsyncOnly(
3808 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3809 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3812void acc::DataOp::addAsyncOperand(
3815 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3816 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3817 getAsyncOperandsMutable()));
3820void acc::DataOp::addWaitOnly(
MLIRContext *context,
3822 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3823 effectiveDeviceTypes));
3826void acc::DataOp::addWaitOperands(
3831 if (getWaitOperandsSegments())
3832 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3834 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3835 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3836 getWaitOperandsMutable(), segments));
3837 setWaitOperandsSegments(segments);
3840 if (getHasWaitDevnumAttr())
3841 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3844 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3846 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3853LogicalResult acc::ExitDataOp::verify() {
3857 if (getDataClauseOperands().empty())
3858 return emitError(
"at least one operand must be present in dataOperands on "
3859 "the exit data operation");
3863 if (getAsyncOperand() && getAsync())
3864 return emitError(
"async attribute cannot appear with asyncOperand");
3868 if (!getWaitOperands().empty() && getWait())
3869 return emitError(
"wait attribute cannot appear with waitOperands");
3871 if (getWaitDevnum() && getWaitOperands().empty())
3872 return emitError(
"wait_devnum cannot appear without waitOperands");
3877unsigned ExitDataOp::getNumDataOperands() {
3878 return getDataClauseOperands().size();
3881Value ExitDataOp::getDataOperand(
unsigned i) {
3882 unsigned numOptional = getIfCond() ? 1 : 0;
3883 numOptional += getAsyncOperand() ? 1 : 0;
3884 numOptional += getWaitDevnum() ? 1 : 0;
3885 return getOperand(getWaitOperands().size() + numOptional + i);
3890 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3893void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getAsyncAttr());
3897 assert(!getAsyncOperand());
3899 setAsyncAttr(mlir::UnitAttr::get(context));
3902void ExitDataOp::addAsyncOperand(
3905 assert(effectiveDeviceTypes.empty());
3906 assert(!getAsyncAttr());
3907 assert(!getAsyncOperand());
3909 getAsyncOperandMutable().append(newValue);
3914 assert(effectiveDeviceTypes.empty());
3915 assert(!getWaitAttr());
3916 assert(getWaitOperands().empty());
3917 assert(!getWaitDevnum());
3919 setWaitAttr(mlir::UnitAttr::get(context));
3922void ExitDataOp::addWaitOperands(
3925 assert(effectiveDeviceTypes.empty());
3926 assert(!getWaitAttr());
3927 assert(getWaitOperands().empty());
3928 assert(!getWaitDevnum());
3933 getWaitDevnumMutable().append(newValues.front());
3934 newValues = newValues.drop_front();
3937 getWaitOperandsMutable().append(newValues);
3944LogicalResult acc::EnterDataOp::verify() {
3948 if (getDataClauseOperands().empty())
3949 return emitError(
"at least one operand must be present in dataOperands on "
3950 "the enter data operation");
3954 if (getAsyncOperand() && getAsync())
3955 return emitError(
"async attribute cannot appear with asyncOperand");
3959 if (!getWaitOperands().empty() && getWait())
3960 return emitError(
"wait attribute cannot appear with waitOperands");
3962 if (getWaitDevnum() && getWaitOperands().empty())
3963 return emitError(
"wait_devnum cannot appear without waitOperands");
3965 for (
mlir::Value operand : getDataClauseOperands())
3966 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3967 operand.getDefiningOp()))
3968 return emitError(
"expect data entry operation as defining op");
3973unsigned EnterDataOp::getNumDataOperands() {
3974 return getDataClauseOperands().size();
3977Value EnterDataOp::getDataOperand(
unsigned i) {
3978 unsigned numOptional = getIfCond() ? 1 : 0;
3979 numOptional += getAsyncOperand() ? 1 : 0;
3980 numOptional += getWaitDevnum() ? 1 : 0;
3981 return getOperand(getWaitOperands().size() + numOptional + i);
3986 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3989void EnterDataOp::addAsyncOnly(
3991 assert(effectiveDeviceTypes.empty());
3992 assert(!getAsyncAttr());
3993 assert(!getAsyncOperand());
3995 setAsyncAttr(mlir::UnitAttr::get(context));
3998void EnterDataOp::addAsyncOperand(
4001 assert(effectiveDeviceTypes.empty());
4002 assert(!getAsyncAttr());
4003 assert(!getAsyncOperand());
4005 getAsyncOperandMutable().append(newValue);
4008void EnterDataOp::addWaitOnly(
MLIRContext *context,
4010 assert(effectiveDeviceTypes.empty());
4011 assert(!getWaitAttr());
4012 assert(getWaitOperands().empty());
4013 assert(!getWaitDevnum());
4015 setWaitAttr(mlir::UnitAttr::get(context));
4018void EnterDataOp::addWaitOperands(
4021 assert(effectiveDeviceTypes.empty());
4022 assert(!getWaitAttr());
4023 assert(getWaitOperands().empty());
4024 assert(!getWaitDevnum());
4029 getWaitDevnumMutable().append(newValues.front());
4030 newValues = newValues.drop_front();
4033 getWaitOperandsMutable().append(newValues);
4040LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4046LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4052LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4059 if (
Value writeVal = op.getWriteOpVal()) {
4068LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4070LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4076AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4077 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4079 return dyn_cast<AtomicReadOp>(getSecondOp());
4082AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4083 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4085 return dyn_cast<AtomicWriteOp>(getSecondOp());
4088AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4089 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4091 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4094LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4100template <
typename Op>
4103 bool requireAtLeastOneOperand =
true) {
4104 if (operands.empty() && requireAtLeastOneOperand)
4107 "at least one operand must appear on the declare operation");
4110 if (isa<BlockArgument>(operand) ||
4111 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4112 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4113 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4114 operand.getDefiningOp()))
4116 "expect valid declare data entry operation or acc.getdeviceptr "
4120 assert(var &&
"declare operands can only be data entry operations which "
4123 std::optional<mlir::acc::DataClause> dataClauseOptional{
4125 assert(dataClauseOptional.has_value() &&
4126 "declare operands can only be data entry operations which must have "
4128 (
void)dataClauseOptional;
4134LogicalResult acc::DeclareEnterOp::verify() {
4142LogicalResult acc::DeclareExitOp::verify() {
4153LogicalResult acc::DeclareOp::verify() {
4162 acc::DeviceType dtype) {
4163 unsigned parallelism = 0;
4164 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4165 parallelism += op.hasWorker(dtype) ? 1 : 0;
4166 parallelism += op.hasVector(dtype) ? 1 : 0;
4167 parallelism += op.hasSeq(dtype) ? 1 : 0;
4171LogicalResult acc::RoutineOp::verify() {
4172 unsigned baseParallelism =
4175 if (baseParallelism > 1)
4176 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4177 "be present at the same time";
4179 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4181 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4182 if (dtype == acc::DeviceType::None)
4186 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4187 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4188 "be present at the same time for device_type `"
4189 << acc::stringifyDeviceType(dtype) <<
"`";
4196 mlir::ArrayAttr &bindIdName,
4197 mlir::ArrayAttr &bindStrName,
4198 mlir::ArrayAttr &deviceIdTypes,
4199 mlir::ArrayAttr &deviceStrTypes) {
4206 mlir::Attribute newAttr;
4207 bool isSymbolRefAttr;
4208 auto parseResult = parser.parseAttribute(newAttr);
4209 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4210 bindIdNameAttrs.push_back(symbolRefAttr);
4211 isSymbolRefAttr = true;
4212 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4213 bindStrNameAttrs.push_back(stringAttr);
4214 isSymbolRefAttr =
false;
4219 if (isSymbolRefAttr) {
4220 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4221 parser.getContext(), mlir::acc::DeviceType::None));
4223 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4224 parser.getContext(), mlir::acc::DeviceType::None));
4227 if (isSymbolRefAttr) {
4228 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4229 parser.parseRSquare())
4232 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4233 parser.parseRSquare())
4241 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4242 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4243 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4244 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4250 std::optional<mlir::ArrayAttr> bindIdName,
4251 std::optional<mlir::ArrayAttr> bindStrName,
4252 std::optional<mlir::ArrayAttr> deviceIdTypes,
4253 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4260 allBindNames.append(bindIdName->begin(), bindIdName->end());
4261 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4266 allBindNames.append(bindStrName->begin(), bindStrName->end());
4267 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4271 if (!allBindNames.empty())
4272 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4273 [&](
const auto &pair) {
4274 p << std::get<0>(pair);
4280 mlir::ArrayAttr &gang,
4281 mlir::ArrayAttr &gangDim,
4282 mlir::ArrayAttr &gangDimDeviceTypes) {
4285 gangDimDeviceTypeAttrs;
4286 bool needCommaBeforeOperands =
false;
4290 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4291 parser.
getContext(), mlir::acc::DeviceType::None));
4292 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4299 if (parser.parseAttribute(gangAttrs.emplace_back()))
4306 needCommaBeforeOperands =
true;
4309 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4313 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4314 parser.parseColon() ||
4315 parser.parseAttribute(gangDimAttrs.emplace_back()))
4317 if (succeeded(parser.parseOptionalLSquare())) {
4318 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4319 parser.parseRSquare())
4322 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4323 parser.getContext(), mlir::acc::DeviceType::None));
4329 if (
failed(parser.parseRParen()))
4332 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4333 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4334 gangDimDeviceTypes =
4335 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4341 std::optional<mlir::ArrayAttr> gang,
4342 std::optional<mlir::ArrayAttr> gangDim,
4343 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4346 gang->size() == 1) {
4347 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4348 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4360 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4361 [&](
const auto &pair) {
4362 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4363 p << std::get<0>(pair);
4371 mlir::ArrayAttr &deviceTypes) {
4375 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4376 parser.
getContext(), mlir::acc::DeviceType::None));
4377 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4384 if (parser.parseAttribute(attributes.emplace_back()))
4392 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4398 std::optional<mlir::ArrayAttr> deviceTypes) {
4401 auto deviceTypeAttr =
4402 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4403 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4412 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4418bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4420bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4424bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4426bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4430bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4432bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4436std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4437RoutineOp::getBindNameValue() {
4438 return getBindNameValue(mlir::acc::DeviceType::None);
4441std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4442RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4445 return std::nullopt;
4448 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4449 auto attr = (*getBindIdName())[*pos];
4450 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4451 assert(symbolRefAttr &&
"expected SymbolRef");
4452 return symbolRefAttr;
4455 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4456 auto attr = (*getBindStrName())[*pos];
4457 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4458 assert(stringAttr &&
"expected String");
4462 return std::nullopt;
4465bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4467bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4471std::optional<int64_t> RoutineOp::getGangDimValue() {
4472 return getGangDimValue(mlir::acc::DeviceType::None);
4475std::optional<int64_t>
4476RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4478 return std::nullopt;
4479 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4480 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4481 return intAttr.getInt();
4483 return std::nullopt;
4488 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4489 effectiveDeviceTypes));
4494 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4495 effectiveDeviceTypes));
4500 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4501 effectiveDeviceTypes));
4506 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4507 effectiveDeviceTypes));
4516 if (getGangDimAttr())
4517 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4518 if (getGangDimDeviceTypeAttr())
4519 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4521 assert(dimValues.size() == deviceTypes.size());
4523 if (effectiveDeviceTypes.empty()) {
4524 dimValues.push_back(
4525 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4526 deviceTypes.push_back(
4527 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4529 for (DeviceType dt : effectiveDeviceTypes) {
4530 dimValues.push_back(
4531 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4532 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4535 assert(dimValues.size() == deviceTypes.size());
4537 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4538 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4541void RoutineOp::addBindStrName(
MLIRContext *context,
4543 mlir::StringAttr val) {
4544 unsigned before = getBindStrNameDeviceTypeAttr()
4545 ? getBindStrNameDeviceTypeAttr().size()
4548 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4549 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4550 unsigned after = getBindStrNameDeviceTypeAttr().size();
4553 if (getBindStrNameAttr())
4554 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4555 for (
unsigned i = 0; i < after - before; ++i)
4556 vals.push_back(val);
4558 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4561void RoutineOp::addBindIDName(
MLIRContext *context,
4563 mlir::SymbolRefAttr val) {
4565 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4567 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4568 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4569 unsigned after = getBindIdNameDeviceTypeAttr().size();
4572 if (getBindIdNameAttr())
4573 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4574 for (
unsigned i = 0; i < after - before; ++i)
4575 vals.push_back(val);
4577 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4584LogicalResult acc::InitOp::verify() {
4588 return emitOpError(
"cannot be nested in a compute operation");
4592void acc::InitOp::addDeviceType(
MLIRContext *context,
4593 mlir::acc::DeviceType deviceType) {
4595 if (getDeviceTypesAttr())
4596 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4598 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4599 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4606LogicalResult acc::ShutdownOp::verify() {
4610 return emitOpError(
"cannot be nested in a compute operation");
4614void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4615 mlir::acc::DeviceType deviceType) {
4617 if (getDeviceTypesAttr())
4618 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4620 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4621 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4628LogicalResult acc::SetOp::verify() {
4632 return emitOpError(
"cannot be nested in a compute operation");
4633 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4634 return emitOpError(
"at least one default_async, device_num, or device_type "
4635 "operand must appear");
4643LogicalResult acc::UpdateOp::verify() {
4645 if (getDataClauseOperands().empty())
4646 return emitError(
"at least one value must be present in dataOperands");
4649 getAsyncOperandsDeviceTypeAttr(),
4654 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4655 getWaitOperandsDeviceTypeAttr(),
"wait")))
4661 for (
mlir::Value operand : getDataClauseOperands())
4662 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4663 operand.getDefiningOp()))
4664 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4670unsigned UpdateOp::getNumDataOperands() {
4671 return getDataClauseOperands().size();
4674Value UpdateOp::getDataOperand(
unsigned i) {
4676 numOptional += getIfCond() ? 1 : 0;
4677 return getOperand(getWaitOperands().size() + numOptional + i);
4682 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4685bool UpdateOp::hasAsyncOnly() {
4686 return hasAsyncOnly(mlir::acc::DeviceType::None);
4689bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4694 return getAsyncValue(mlir::acc::DeviceType::None);
4697mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4707bool UpdateOp::hasWaitOnly() {
4708 return hasWaitOnly(mlir::acc::DeviceType::None);
4711bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4716 return getWaitValues(mlir::acc::DeviceType::None);
4720UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4722 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4723 getHasWaitDevnum(), deviceType);
4727 return getWaitDevnum(mlir::acc::DeviceType::None);
4730mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4732 getWaitOperandsSegments(), getHasWaitDevnum(),
4738 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4739 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4742void UpdateOp::addAsyncOperand(
4745 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4746 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4747 getAsyncOperandsMutable()));
4752 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4753 effectiveDeviceTypes));
4756void UpdateOp::addWaitOperands(
4761 if (getWaitOperandsSegments())
4762 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4764 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4765 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4766 getWaitOperandsMutable(), segments));
4767 setWaitOperandsSegments(segments);
4770 if (getHasWaitDevnumAttr())
4771 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4774 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4776 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4783LogicalResult acc::WaitOp::verify() {
4786 if (getAsyncOperand() && getAsync())
4787 return emitError(
"async attribute cannot appear with asyncOperand");
4789 if (getWaitDevnum() && getWaitOperands().empty())
4790 return emitError(
"wait_devnum cannot appear without waitOperands");
4795#define GET_OP_CLASSES
4796#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4798#define GET_ATTRDEF_CLASSES
4799#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4801#define GET_TYPEDEF_CLASSES
4802#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4813 .Case<ACC_DATA_ENTRY_OPS>(
4814 [&](
auto entry) {
return entry.getVarPtr(); })
4815 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4816 [&](
auto exit) {
return exit.getVarPtr(); })
4834 [&](
auto entry) {
return entry.getVarType(); })
4835 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4836 [&](
auto exit) {
return exit.getVarType(); })
4846 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4847 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4857 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4866 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4876 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4878 dataClause.getBounds().begin(), dataClause.getBounds().end());
4890 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4892 dataClause.getAsyncOperands().begin(),
4893 dataClause.getAsyncOperands().end());
4904 return dataClause.getAsyncOperandsDeviceTypeAttr();
4912 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4919 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4926std::optional<mlir::acc::DataClause>
4931 .Case<ACC_DATA_ENTRY_OPS>(
4932 [&](
auto entry) {
return entry.getDataClause(); })
4940 [&](
auto entry) {
return entry.getImplicit(); })
4949 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4951 return dataOperands;
4959 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4961 return dataOperands;
4968 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region ®ion, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > ®ions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.