25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249 bool isDeviceData(Type pointer, Value var)
const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
256struct LLVMPointerPointerLikeModel
257 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
261 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType)
const {
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
278struct MemrefAddressOfGlobalModel
279 :
public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op)
const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
287struct MemrefGlobalVariableModel
288 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
290 bool isConstant(Operation *op)
const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
295 Region *getInitRegion(Operation *op)
const {
300 bool isDeviceData(Operation *op)
const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
307struct GPULaunchOffloadRegionModel
308 :
public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
311 return cast<gpu::LaunchOp>(op).getBody();
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
333 return mlir::ArrayAttr::get(context, deviceTypes);
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
351 if (newDeviceTypes.empty()) {
352 argCollection.
append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.
append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
364 return mlir::ArrayAttr::get(context, deviceTypes);
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
383void OpenACCDialect::initialize() {
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
448void ParallelOp::getSuccessorRegions(
478void HostDataOp::getSuccessorRegions(
493 if (getUnstructured()) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
530 mlir::acc::DeviceType deviceType) {
534 for (
auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
544 std::optional<mlir::ArrayAttr> deviceTypes) {
549 llvm::interleaveComma(*deviceTypes, p,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (
auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
570 mlir::acc::DeviceType deviceType) {
572 return range.take_front(0);
573 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (
unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
579 return range.take_front(0);
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
590 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
608 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1);
618template <
typename Op>
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
622 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
627 op.hasAsyncOnly(dtype))
629 "asyncOnly attribute cannot appear with asyncOperand");
634 op.hasWaitOnly(dtype))
635 return op.
emitError(
"wait attribute cannot appear with waitOperands");
640template <
typename Op>
643 return op.
emitError(
"must have var operand");
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.
emitError(
"var must be mappable or pointer-like");
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.
emitError(
"varType must capture the element type of var");
658template <
typename Op>
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.
emitError(
"input and output types must match");
666template <
typename Op>
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.
emitError(
"no data clause modifiers are allowed");
673template <
typename Op>
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
684template <
typename OpT,
typename RecipeOpT>
685static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
690 !std::is_same_v<OpT, acc::ReductionOp>)
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
695 return op->emitOpError() <<
"recipe expected for " << operandName;
700 return op->emitOpError()
701 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
702 << operandName <<
" declaration";
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
744 if (failed(parser.
parseType(accVarType)))
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.
parseType(varPtrType)))
778 varTypeAttr = mlir::TypeAttr::get(varType);
783 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
784 varTypeAttr = mlir::TypeAttr::get(
785 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
787 varTypeAttr = mlir::TypeAttr::get(varPtrType);
794 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
802 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
803 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
805 if (typeToCheckAgainst != varType) {
813 mlir::SymbolRefAttr &recipeAttr) {
820 mlir::SymbolRefAttr recipeAttr) {
827LogicalResult acc::DataBoundsOp::verify() {
828 auto extent = getExtent();
829 auto upperbound = getUpperbound();
830 if (!extent && !upperbound)
831 return emitError(
"expected extent or upperbound.");
838LogicalResult acc::PrivateOp::verify() {
841 "data clause associated with private operation must match its intent");
855LogicalResult acc::FirstprivateOp::verify() {
857 return emitError(
"data clause associated with firstprivate operation must "
864 *
this,
"firstprivate")))
872LogicalResult acc::ReductionOp::verify() {
874 return emitError(
"data clause associated with reduction operation must "
881 *
this,
"reduction")))
889LogicalResult acc::DevicePtrOp::verify() {
891 return emitError(
"data clause associated with deviceptr operation must "
905LogicalResult acc::PresentOp::verify() {
908 "data clause associated with present operation must match its intent");
921LogicalResult acc::CopyinOp::verify() {
923 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
928 "data clause associated with copyin operation must match its intent"
929 " or specify original clause this operation was decomposed from");
935 acc::DataClauseModifier::always |
936 acc::DataClauseModifier::capture)))
941bool acc::CopyinOp::isCopyinReadonly() {
942 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
943 acc::bitEnumContainsAny(getModifiers(),
944 acc::DataClauseModifier::readonly);
950LogicalResult acc::CreateOp::verify() {
957 "data clause associated with create operation must match its intent"
958 " or specify original clause this operation was decomposed from");
966 acc::DataClauseModifier::always |
967 acc::DataClauseModifier::capture)))
972bool acc::CreateOp::isCreateZero() {
974 return getDataClause() == acc::DataClause::acc_create_zero ||
976 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
982LogicalResult acc::NoCreateOp::verify() {
984 return emitError(
"data clause associated with no_create operation must "
998LogicalResult acc::AttachOp::verify() {
1001 "data clause associated with attach operation must match its intent");
1015LogicalResult acc::DeclareDeviceResidentOp::verify() {
1016 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1017 return emitError(
"data clause associated with device_resident operation "
1018 "must match its intent");
1032LogicalResult acc::DeclareLinkOp::verify() {
1035 "data clause associated with link operation must match its intent");
1048LogicalResult acc::CopyoutOp::verify() {
1055 "data clause associated with copyout operation must match its intent"
1056 " or specify original clause this operation was decomposed from");
1058 return emitError(
"must have both host and device pointers");
1064 acc::DataClauseModifier::always |
1065 acc::DataClauseModifier::capture)))
1070bool acc::CopyoutOp::isCopyoutZero() {
1071 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1072 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1078LogicalResult acc::DeleteOp::verify() {
1087 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1090 "data clause associated with delete operation must match its intent"
1091 " or specify original clause this operation was decomposed from");
1093 return emitError(
"must have device pointer");
1097 acc::DataClauseModifier::readonly |
1098 acc::DataClauseModifier::always |
1099 acc::DataClauseModifier::capture)))
1107LogicalResult acc::DetachOp::verify() {
1112 "data clause associated with detach operation must match its intent"
1113 " or specify original clause this operation was decomposed from");
1115 return emitError(
"must have device pointer");
1124LogicalResult acc::UpdateHostOp::verify() {
1129 "data clause associated with host operation must match its intent"
1130 " or specify original clause this operation was decomposed from");
1132 return emitError(
"must have both host and device pointers");
1145LogicalResult acc::UpdateDeviceOp::verify() {
1149 "data clause associated with device operation must match its intent"
1150 " or specify original clause this operation was decomposed from");
1163LogicalResult acc::UseDeviceOp::verify() {
1167 "data clause associated with use_device operation must match its intent"
1168 " or specify original clause this operation was decomposed from");
1181LogicalResult acc::CacheOp::verify() {
1186 "data clause associated with cache operation must match its intent"
1187 " or specify original clause this operation was decomposed from");
1197bool acc::CacheOp::isCacheReadonly() {
1198 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1199 acc::bitEnumContainsAny(getModifiers(),
1200 acc::DataClauseModifier::readonly);
1214 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
1222template <
typename EffectTy>
1227 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1228 effects.emplace_back(EffectTy::get(), &operand[i]);
1232template <
typename EffectTy>
1237 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1241void acc::PrivateOp::getEffects(
1255void acc::FirstprivateOp::getEffects(
1269void acc::ReductionOp::getEffects(
1283void acc::DevicePtrOp::getEffects(
1292void acc::PresentOp::getEffects(
1303void acc::CopyinOp::getEffects(
1316void acc::CreateOp::getEffects(
1329void acc::NoCreateOp::getEffects(
1340void acc::AttachOp::getEffects(
1353void acc::GetDevicePtrOp::getEffects(
1362void acc::UpdateDeviceOp::getEffects(
1372void acc::UseDeviceOp::getEffects(
1381void acc::DeclareDeviceResidentOp::getEffects(
1392void acc::DeclareLinkOp::getEffects(
1403void acc::CacheOp::getEffects(
1408void acc::CopyoutOp::getEffects(
1421void acc::DeleteOp::getEffects(
1433void acc::DetachOp::getEffects(
1445void acc::UpdateHostOp::getEffects(
1457template <
typename StructureOp>
1459 unsigned nRegions = 1) {
1462 for (
unsigned i = 0; i < nRegions; ++i)
1465 for (
Region *region : regions)
1473 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1480template <
typename OpTy>
1482 using OpRewritePattern<OpTy>::OpRewritePattern;
1484 LogicalResult matchAndRewrite(OpTy op,
1485 PatternRewriter &rewriter)
const override {
1487 Value ifCond = op.getIfCond();
1491 IntegerAttr constAttr;
1494 if (constAttr.getInt())
1495 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1507 assert(region.
hasOneBlock() &&
"expected single-block region");
1519template <
typename OpTy>
1520struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1521 using OpRewritePattern<OpTy>::OpRewritePattern;
1523 LogicalResult matchAndRewrite(OpTy op,
1524 PatternRewriter &rewriter)
const override {
1526 Value ifCond = op.getIfCond();
1530 IntegerAttr constAttr;
1533 if (constAttr.getInt())
1534 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1556 for (
Value bound : bounds) {
1557 argTypes.push_back(bound.getType());
1558 argLocs.push_back(loc);
1565 Value privatizedValue;
1571 if (isa<MappableType>(varType)) {
1572 auto mappableTy = cast<MappableType>(varType);
1573 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1574 privatizedValue = mappableTy.generatePrivateInit(
1575 builder, loc, typedVar, varName, bounds, {}, needsFree);
1576 if (!privatizedValue)
1579 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1580 auto pointerLikeTy = cast<PointerLikeType>(varType);
1582 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1583 blockArgVar, needsFree);
1584 if (!privatizedValue)
1589 acc::YieldOp::create(builder, loc, privatizedValue);
1604 for (
Value bound : bounds) {
1605 copyArgTypes.push_back(bound.getType());
1606 copyArgLocs.push_back(loc);
1613 bool isMappable = isa<MappableType>(varType);
1614 bool isPointerLike = isa<PointerLikeType>(varType);
1617 if (isMappable && !isPointerLike)
1621 if (isPointerLike) {
1622 auto pointerLikeTy = cast<PointerLikeType>(varType);
1627 if (!pointerLikeTy.genCopy(
1634 acc::TerminatorOp::create(builder, loc);
1648 for (
Value bound : bounds) {
1649 destroyArgTypes.push_back(bound.getType());
1650 destroyArgLocs.push_back(loc);
1654 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1658 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1659 if (isa<MappableType>(varType)) {
1660 auto mappableTy = cast<MappableType>(varType);
1661 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1664 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1665 auto pointerLikeTy = cast<PointerLikeType>(varType);
1666 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1670 acc::TerminatorOp::create(builder, loc);
1681 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1683 if (optional && region.
empty())
1687 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1691 return op->
emitOpError() <<
"expects " << regionName
1694 << regionType <<
" type";
1697 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1698 if (yieldOp.getOperands().size() != 1 ||
1699 yieldOp.getOperands().getTypes()[0] != type)
1700 return op->
emitOpError() <<
"expects " << regionName
1702 "yield a value of the "
1703 << regionType <<
" type";
1709LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1711 "privatization",
"init",
getType(),
1715 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1721std::optional<PrivateRecipeOp>
1723 StringRef recipeName,
Type varType,
1726 bool isMappable = isa<MappableType>(varType);
1727 bool isPointerLike = isa<PointerLikeType>(varType);
1730 if (!isMappable && !isPointerLike)
1731 return std::nullopt;
1736 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1739 bool needsFree =
false;
1740 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1741 varName, bounds, needsFree))) {
1743 return std::nullopt;
1750 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1751 Value allocRes = yieldOp.getOperand(0);
1753 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1754 varType, allocRes, bounds))) {
1756 return std::nullopt;
1763std::optional<PrivateRecipeOp>
1765 StringRef recipeName,
1766 FirstprivateRecipeOp firstprivRecipe) {
1769 auto varType = firstprivRecipe.getType();
1770 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1774 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1777 if (!firstprivRecipe.getDestroyRegion().empty()) {
1779 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1789LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1791 "privatization",
"init",
getType(),
1795 if (getCopyRegion().empty())
1796 return emitOpError() <<
"expects non-empty copy region";
1801 return emitOpError() <<
"expects copy region with two arguments of the "
1802 "privatization type";
1804 if (getDestroyRegion().empty())
1808 "privatization",
"destroy",
1815std::optional<FirstprivateRecipeOp>
1817 StringRef recipeName,
Type varType,
1820 bool isMappable = isa<MappableType>(varType);
1821 bool isPointerLike = isa<PointerLikeType>(varType);
1824 if (!isMappable && !isPointerLike)
1825 return std::nullopt;
1830 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1833 bool needsFree =
false;
1834 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1835 varName, bounds, needsFree))) {
1837 return std::nullopt;
1841 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1844 return std::nullopt;
1851 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1852 Value allocRes = yieldOp.getOperand(0);
1854 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1855 varType, allocRes, bounds))) {
1857 return std::nullopt;
1868LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1874 if (getCombinerRegion().empty())
1875 return emitOpError() <<
"expects non-empty combiner region";
1877 Block &reductionBlock = getCombinerRegion().
front();
1881 return emitOpError() <<
"expects combiner region with the first two "
1882 <<
"arguments of the reduction type";
1884 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1885 if (yieldOp.getOperands().size() != 1 ||
1886 yieldOp.getOperands().getTypes()[0] !=
getType())
1887 return emitOpError() <<
"expects combiner region to yield a value "
1888 "of the reduction type";
1899template <
typename Op>
1903 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1904 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1905 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1906 operand.getDefiningOp()))
1908 "expect data entry/exit operation or acc.getdeviceptr "
1913template <
typename OpT,
typename RecipeOpT>
1916 llvm::StringRef operandName) {
1919 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1921 <<
"expected " << operandName <<
" as defining op";
1922 if (!set.insert(operand).second)
1924 << operandName <<
" operand appears more than once";
1929unsigned ParallelOp::getNumDataOperands() {
1930 return getReductionOperands().size() + getPrivateOperands().size() +
1931 getFirstprivateOperands().size() + getDataClauseOperands().size();
1934Value ParallelOp::getDataOperand(
unsigned i) {
1936 numOptional += getNumGangs().size();
1937 numOptional += getNumWorkers().size();
1938 numOptional += getVectorLength().size();
1939 numOptional += getIfCond() ? 1 : 0;
1940 numOptional += getSelfCond() ? 1 : 0;
1941 return getOperand(getWaitOperands().size() + numOptional + i);
1944template <
typename Op>
1947 llvm::StringRef keyword) {
1948 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1949 return op.
emitOpError() << keyword <<
" operands count must match "
1950 << keyword <<
" device_type count";
1954template <
typename Op>
1957 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1958 std::size_t numOperandsInSegments = 0;
1959 std::size_t nbOfSegments = 0;
1962 for (
auto segCount : segments.
asArrayRef()) {
1963 if (maxInSegment != 0 && segCount > maxInSegment)
1964 return op.
emitOpError() << keyword <<
" expects a maximum of "
1965 << maxInSegment <<
" values per segment";
1966 numOperandsInSegments += segCount;
1971 if ((numOperandsInSegments != operands.size()) ||
1972 (!deviceTypes && !operands.empty()))
1974 << keyword <<
" operand count does not match count in segments";
1975 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1977 << keyword <<
" segment count does not match device_type count";
1981LogicalResult acc::ParallelOp::verify() {
1983 mlir::acc::PrivateRecipeOp>(
1984 *
this, getPrivateOperands(),
"private")))
1987 mlir::acc::FirstprivateRecipeOp>(
1988 *
this, getFirstprivateOperands(),
"firstprivate")))
1991 mlir::acc::ReductionRecipeOp>(
1992 *
this, getReductionOperands(),
"reduction")))
1996 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1997 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2001 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2002 getWaitOperandsDeviceTypeAttr(),
"wait")))
2006 getNumWorkersDeviceTypeAttr(),
2011 getVectorLengthDeviceTypeAttr(),
2016 getAsyncOperandsDeviceTypeAttr(),
2029 mlir::acc::DeviceType deviceType) {
2032 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2037bool acc::ParallelOp::hasAsyncOnly() {
2038 return hasAsyncOnly(mlir::acc::DeviceType::None);
2041bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2046 return getAsyncValue(mlir::acc::DeviceType::None);
2049mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2054mlir::Value acc::ParallelOp::getNumWorkersValue() {
2055 return getNumWorkersValue(mlir::acc::DeviceType::None);
2059acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2064mlir::Value acc::ParallelOp::getVectorLengthValue() {
2065 return getVectorLengthValue(mlir::acc::DeviceType::None);
2069acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2071 getVectorLength(), deviceType);
2075 return getNumGangsValues(mlir::acc::DeviceType::None);
2079ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2081 getNumGangsSegments(), deviceType);
2084bool acc::ParallelOp::hasWaitOnly() {
2085 return hasWaitOnly(mlir::acc::DeviceType::None);
2088bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2093 return getWaitValues(mlir::acc::DeviceType::None);
2097ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2099 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2100 getHasWaitDevnum(), deviceType);
2104 return getWaitDevnum(mlir::acc::DeviceType::None);
2107mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2109 getWaitOperandsSegments(), getHasWaitDevnum(),
2124 odsBuilder, odsState, asyncOperands,
nullptr,
2125 nullptr, waitOperands,
nullptr,
2127 nullptr, numGangs,
nullptr,
2128 nullptr, numWorkers,
2129 nullptr, vectorLength,
2130 nullptr, ifCond, selfCond,
2131 nullptr, reductionOperands, gangPrivateOperands,
2132 gangFirstPrivateOperands, dataClauseOperands,
2136void acc::ParallelOp::addNumWorkersOperand(
2139 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2140 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2141 getNumWorkersMutable()));
2143void acc::ParallelOp::addVectorLengthOperand(
2146 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2147 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2148 getVectorLengthMutable()));
2151void acc::ParallelOp::addAsyncOnly(
2153 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2154 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2157void acc::ParallelOp::addAsyncOperand(
2160 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2161 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2162 getAsyncOperandsMutable()));
2165void acc::ParallelOp::addNumGangsOperands(
2169 if (getNumGangsSegments())
2170 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2172 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2173 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2174 getNumGangsMutable(), segments));
2176 setNumGangsSegments(segments);
2178void acc::ParallelOp::addWaitOnly(
2180 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2181 effectiveDeviceTypes));
2183void acc::ParallelOp::addWaitOperands(
2188 if (getWaitOperandsSegments())
2189 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2191 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2192 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2193 getWaitOperandsMutable(), segments));
2194 setWaitOperandsSegments(segments);
2197 if (getHasWaitDevnumAttr())
2198 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2201 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2203 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2206void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2207 mlir::acc::PrivateOp op,
2208 mlir::acc::PrivateRecipeOp recipe) {
2209 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2210 getPrivateOperandsMutable().append(op.getResult());
2213void acc::ParallelOp::addFirstPrivatization(
2214 MLIRContext *context, mlir::acc::FirstprivateOp op,
2215 mlir::acc::FirstprivateRecipeOp recipe) {
2216 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2217 getFirstprivateOperandsMutable().append(op.getResult());
2220void acc::ParallelOp::addReduction(
MLIRContext *context,
2221 mlir::acc::ReductionOp op,
2222 mlir::acc::ReductionRecipeOp recipe) {
2223 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2224 getReductionOperandsMutable().append(op.getResult());
2239 int32_t crtOperandsSize = operands.size();
2242 if (parser.parseOperand(operands.emplace_back()) ||
2243 parser.parseColonType(types.emplace_back()))
2248 seg.push_back(operands.size() - crtOperandsSize);
2258 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2259 parser.
getContext(), mlir::acc::DeviceType::None));
2265 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2272 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2273 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2274 p <<
" [" << attr <<
"]";
2279 std::optional<mlir::ArrayAttr> deviceTypes,
2280 std::optional<mlir::DenseI32ArrayAttr> segments) {
2282 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2284 llvm::interleaveComma(
2285 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2286 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2306 int32_t crtOperandsSize = operands.size();
2310 if (parser.parseOperand(operands.emplace_back()) ||
2311 parser.parseColonType(types.emplace_back()))
2317 seg.push_back(operands.size() - crtOperandsSize);
2327 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2328 parser.
getContext(), mlir::acc::DeviceType::None));
2334 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2343 std::optional<mlir::DenseI32ArrayAttr> segments) {
2345 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2347 llvm::interleaveComma(
2348 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2349 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2362 mlir::ArrayAttr &keywordOnly) {
2366 bool needCommaBeforeOperands =
false;
2370 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2371 parser.
getContext(), mlir::acc::DeviceType::None));
2372 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2379 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2386 needCommaBeforeOperands =
true;
2389 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2396 int32_t crtOperandsSize = operands.size();
2408 if (parser.parseOperand(operands.emplace_back()) ||
2409 parser.parseColonType(types.emplace_back()))
2415 seg.push_back(operands.size() - crtOperandsSize);
2425 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2426 parser.
getContext(), mlir::acc::DeviceType::None));
2433 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2434 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2436 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2444 if (attrs->size() != 1)
2446 if (
auto deviceTypeAttr =
2447 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2448 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2454 std::optional<mlir::ArrayAttr> deviceTypes,
2455 std::optional<mlir::DenseI32ArrayAttr> segments,
2456 std::optional<mlir::ArrayAttr> hasDevNum,
2457 std::optional<mlir::ArrayAttr> keywordOnly) {
2470 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2472 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2473 if (boolAttr && boolAttr.getValue())
2475 llvm::interleaveComma(
2476 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2477 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2494 if (parser.parseOperand(operands.emplace_back()) ||
2495 parser.parseColonType(types.emplace_back()))
2497 if (succeeded(parser.parseOptionalLSquare())) {
2498 if (parser.parseAttribute(attributes.emplace_back()) ||
2499 parser.parseRSquare())
2502 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2503 parser.getContext(), mlir::acc::DeviceType::None));
2510 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2517 std::optional<mlir::ArrayAttr> deviceTypes) {
2520 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2521 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2530 mlir::ArrayAttr &keywordOnlyDeviceType) {
2533 bool needCommaBeforeOperands =
false;
2537 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2538 parser.
getContext(), mlir::acc::DeviceType::None));
2539 keywordOnlyDeviceType =
2540 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2548 if (parser.parseAttribute(
2549 keywordOnlyDeviceTypeAttributes.emplace_back()))
2556 needCommaBeforeOperands =
true;
2559 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2564 if (parser.parseOperand(operands.emplace_back()) ||
2565 parser.parseColonType(types.emplace_back()))
2567 if (succeeded(parser.parseOptionalLSquare())) {
2568 if (parser.parseAttribute(attributes.emplace_back()) ||
2569 parser.parseRSquare())
2572 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2573 parser.getContext(), mlir::acc::DeviceType::None));
2579 if (
failed(parser.parseRParen()))
2584 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2591 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2593 if (operands.begin() == operands.end() &&
2609 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2610 mlir::Type &operandType, mlir::UnitAttr &attr) {
2613 attr = mlir::UnitAttr::get(parser.
getContext());
2623 if (failed(parser.
parseType(operandType)))
2633 std::optional<mlir::Value> operand,
2635 mlir::UnitAttr attr) {
2652 attr = mlir::UnitAttr::get(parser.
getContext());
2657 if (parser.parseOperand(operands.emplace_back()))
2665 if (parser.parseType(types.emplace_back()))
2680 mlir::UnitAttr attr) {
2685 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2687 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2693 mlir::acc::CombinedConstructsTypeAttr &attr) {
2695 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2696 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2698 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2699 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2701 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2702 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2705 "expected compute construct name");
2713 mlir::acc::CombinedConstructsTypeAttr attr) {
2715 switch (attr.getValue()) {
2716 case mlir::acc::CombinedConstructsType::KernelsLoop:
2719 case mlir::acc::CombinedConstructsType::ParallelLoop:
2722 case mlir::acc::CombinedConstructsType::SerialLoop:
2733unsigned SerialOp::getNumDataOperands() {
2734 return getReductionOperands().size() + getPrivateOperands().size() +
2735 getFirstprivateOperands().size() + getDataClauseOperands().size();
2738Value SerialOp::getDataOperand(
unsigned i) {
2740 numOptional += getIfCond() ? 1 : 0;
2741 numOptional += getSelfCond() ? 1 : 0;
2742 return getOperand(getWaitOperands().size() + numOptional + i);
2745bool acc::SerialOp::hasAsyncOnly() {
2746 return hasAsyncOnly(mlir::acc::DeviceType::None);
2749bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2754 return getAsyncValue(mlir::acc::DeviceType::None);
2757mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2762bool acc::SerialOp::hasWaitOnly() {
2763 return hasWaitOnly(mlir::acc::DeviceType::None);
2766bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2771 return getWaitValues(mlir::acc::DeviceType::None);
2775SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2777 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2778 getHasWaitDevnum(), deviceType);
2782 return getWaitDevnum(mlir::acc::DeviceType::None);
2785mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2787 getWaitOperandsSegments(), getHasWaitDevnum(),
2791LogicalResult acc::SerialOp::verify() {
2793 mlir::acc::PrivateRecipeOp>(
2794 *
this, getPrivateOperands(),
"private")))
2797 mlir::acc::FirstprivateRecipeOp>(
2798 *
this, getFirstprivateOperands(),
"firstprivate")))
2801 mlir::acc::ReductionRecipeOp>(
2802 *
this, getReductionOperands(),
"reduction")))
2806 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2807 getWaitOperandsDeviceTypeAttr(),
"wait")))
2811 getAsyncOperandsDeviceTypeAttr(),
2821void acc::SerialOp::addAsyncOnly(
2823 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2824 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2827void acc::SerialOp::addAsyncOperand(
2830 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2831 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2832 getAsyncOperandsMutable()));
2835void acc::SerialOp::addWaitOnly(
2837 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2838 effectiveDeviceTypes));
2840void acc::SerialOp::addWaitOperands(
2845 if (getWaitOperandsSegments())
2846 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2848 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2850 getWaitOperandsMutable(), segments));
2851 setWaitOperandsSegments(segments);
2854 if (getHasWaitDevnumAttr())
2855 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2858 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2860 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2863void acc::SerialOp::addPrivatization(
MLIRContext *context,
2864 mlir::acc::PrivateOp op,
2865 mlir::acc::PrivateRecipeOp recipe) {
2866 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2867 getPrivateOperandsMutable().append(op.getResult());
2870void acc::SerialOp::addFirstPrivatization(
2871 MLIRContext *context, mlir::acc::FirstprivateOp op,
2872 mlir::acc::FirstprivateRecipeOp recipe) {
2873 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2874 getFirstprivateOperandsMutable().append(op.getResult());
2877void acc::SerialOp::addReduction(
MLIRContext *context,
2878 mlir::acc::ReductionOp op,
2879 mlir::acc::ReductionRecipeOp recipe) {
2880 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2881 getReductionOperandsMutable().append(op.getResult());
2888unsigned KernelsOp::getNumDataOperands() {
2889 return getDataClauseOperands().size();
2892Value KernelsOp::getDataOperand(
unsigned i) {
2894 numOptional += getWaitOperands().size();
2895 numOptional += getNumGangs().size();
2896 numOptional += getNumWorkers().size();
2897 numOptional += getVectorLength().size();
2898 numOptional += getIfCond() ? 1 : 0;
2899 numOptional += getSelfCond() ? 1 : 0;
2900 return getOperand(numOptional + i);
2903bool acc::KernelsOp::hasAsyncOnly() {
2904 return hasAsyncOnly(mlir::acc::DeviceType::None);
2907bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2912 return getAsyncValue(mlir::acc::DeviceType::None);
2915mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2921 return getNumWorkersValue(mlir::acc::DeviceType::None);
2925acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2930mlir::Value acc::KernelsOp::getVectorLengthValue() {
2931 return getVectorLengthValue(mlir::acc::DeviceType::None);
2935acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2937 getVectorLength(), deviceType);
2941 return getNumGangsValues(mlir::acc::DeviceType::None);
2945KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2947 getNumGangsSegments(), deviceType);
2950bool acc::KernelsOp::hasWaitOnly() {
2951 return hasWaitOnly(mlir::acc::DeviceType::None);
2954bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2959 return getWaitValues(mlir::acc::DeviceType::None);
2963KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2965 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2966 getHasWaitDevnum(), deviceType);
2970 return getWaitDevnum(mlir::acc::DeviceType::None);
2973mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2975 getWaitOperandsSegments(), getHasWaitDevnum(),
2979LogicalResult acc::KernelsOp::verify() {
2981 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2982 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2986 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2987 getWaitOperandsDeviceTypeAttr(),
"wait")))
2991 getNumWorkersDeviceTypeAttr(),
2996 getVectorLengthDeviceTypeAttr(),
3001 getAsyncOperandsDeviceTypeAttr(),
3011void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3012 mlir::acc::PrivateOp op,
3013 mlir::acc::PrivateRecipeOp recipe) {
3014 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3015 getPrivateOperandsMutable().append(op.getResult());
3018void acc::KernelsOp::addFirstPrivatization(
3019 MLIRContext *context, mlir::acc::FirstprivateOp op,
3020 mlir::acc::FirstprivateRecipeOp recipe) {
3021 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3022 getFirstprivateOperandsMutable().append(op.getResult());
3025void acc::KernelsOp::addReduction(
MLIRContext *context,
3026 mlir::acc::ReductionOp op,
3027 mlir::acc::ReductionRecipeOp recipe) {
3028 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3029 getReductionOperandsMutable().append(op.getResult());
3032void acc::KernelsOp::addNumWorkersOperand(
3035 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3036 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3037 getNumWorkersMutable()));
3040void acc::KernelsOp::addVectorLengthOperand(
3043 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3044 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3045 getVectorLengthMutable()));
3047void acc::KernelsOp::addAsyncOnly(
3049 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3050 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3053void acc::KernelsOp::addAsyncOperand(
3056 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3057 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3058 getAsyncOperandsMutable()));
3061void acc::KernelsOp::addNumGangsOperands(
3065 if (getNumGangsSegmentsAttr())
3066 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3068 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3069 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3070 getNumGangsMutable(), segments));
3072 setNumGangsSegments(segments);
3075void acc::KernelsOp::addWaitOnly(
3077 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3078 effectiveDeviceTypes));
3080void acc::KernelsOp::addWaitOperands(
3085 if (getWaitOperandsSegments())
3086 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3088 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3089 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3090 getWaitOperandsMutable(), segments));
3091 setWaitOperandsSegments(segments);
3094 if (getHasWaitDevnumAttr())
3095 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3098 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3100 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3107LogicalResult acc::HostDataOp::verify() {
3108 if (getDataClauseOperands().empty())
3109 return emitError(
"at least one operand must appear on the host_data "
3113 for (
mlir::Value operand : getDataClauseOperands()) {
3115 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3117 return emitError(
"expect data entry operation as defining op");
3120 if (!seenVars.insert(useDeviceOp.getVar()).second)
3121 return emitError(
"duplicate use_device variable");
3128 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3140 bool &needCommaBetweenValues,
bool &newValue) {
3147 attributes.push_back(gangArgType);
3148 needCommaBetweenValues =
true;
3159 mlir::ArrayAttr &gangOnlyDeviceType) {
3164 bool needCommaBetweenValues =
false;
3165 bool needCommaBeforeOperands =
false;
3169 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3170 parser.
getContext(), mlir::acc::DeviceType::None));
3171 gangOnlyDeviceType =
3172 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3180 if (parser.parseAttribute(
3181 gangOnlyDeviceTypeAttributes.emplace_back()))
3188 needCommaBeforeOperands =
true;
3191 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3192 mlir::acc::GangArgType::Num);
3193 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3194 mlir::acc::GangArgType::Dim);
3195 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3196 parser.
getContext(), mlir::acc::GangArgType::Static);
3199 if (needCommaBeforeOperands) {
3200 needCommaBeforeOperands =
false;
3207 int32_t crtOperandsSize = gangOperands.size();
3209 bool newValue =
false;
3210 bool needValue =
false;
3211 if (needCommaBetweenValues) {
3219 gangOperands, gangOperandsType,
3220 gangArgTypeAttributes, argNum,
3221 needCommaBetweenValues, newValue)))
3224 gangOperands, gangOperandsType,
3225 gangArgTypeAttributes, argDim,
3226 needCommaBetweenValues, newValue)))
3228 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3229 gangOperands, gangOperandsType,
3230 gangArgTypeAttributes, argStatic,
3231 needCommaBetweenValues, newValue)))
3234 if (!newValue && needValue) {
3236 "new value expected after comma");
3244 if (gangOperands.empty())
3247 "expect at least one of num, dim or static values");
3253 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3257 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3258 parser.
getContext(), mlir::acc::DeviceType::None));
3261 seg.push_back(gangOperands.size() - crtOperandsSize);
3269 gangArgTypeAttributes.end());
3270 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3271 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3274 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3275 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3283 std::optional<mlir::ArrayAttr> gangArgTypes,
3284 std::optional<mlir::ArrayAttr> deviceTypes,
3285 std::optional<mlir::DenseI32ArrayAttr> segments,
3286 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3288 if (operands.begin() == operands.end() &&
3303 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3305 llvm::interleaveComma(
3306 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3307 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3308 (*gangArgTypes)[opIdx]);
3309 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3310 p << LoopOp::getGangNumKeyword();
3311 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3312 p << LoopOp::getGangDimKeyword();
3313 else if (gangArgTypeAttr.getValue() ==
3314 mlir::acc::GangArgType::Static)
3315 p << LoopOp::getGangStaticKeyword();
3316 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3327 std::optional<mlir::ArrayAttr> segments,
3328 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3331 for (
auto attr : *segments) {
3332 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3333 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3341static std::optional<mlir::acc::DeviceType>
3343 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3345 return std::nullopt;
3346 for (
auto attr : deviceTypes) {
3347 auto deviceTypeAttr =
3348 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3349 if (!deviceTypeAttr)
3350 return mlir::acc::DeviceType::None;
3351 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3352 return deviceTypeAttr.getValue();
3354 return std::nullopt;
3357LogicalResult acc::LoopOp::verify() {
3358 if (getUpperbound().size() != getStep().size())
3359 return emitError() <<
"number of upperbounds expected to be the same as "
3362 if (getUpperbound().size() != getLowerbound().size())
3363 return emitError() <<
"number of upperbounds expected to be the same as "
3364 "number of lowerbounds";
3366 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3367 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3368 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3369 <<
" as upperbound size";
3372 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3373 return emitOpError() <<
"collapse device_type attr must be define when"
3374 <<
" collapse attr is present";
3376 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3377 getCollapseAttr().getValue().size() !=
3378 getCollapseDeviceTypeAttr().getValue().size())
3379 return emitOpError() <<
"collapse attribute count must match collapse"
3380 <<
" device_type count";
3381 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3383 << acc::stringifyDeviceType(*duplicateDeviceType)
3384 <<
"` found in collapseDeviceType attribute";
3387 if (!getGangOperands().empty()) {
3388 if (!getGangOperandsArgType())
3389 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3390 <<
" when gang operands are present";
3392 if (getGangOperands().size() !=
3393 getGangOperandsArgTypeAttr().getValue().size())
3394 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3395 <<
" gangOperands count";
3397 if (getGangAttr()) {
3400 << acc::stringifyDeviceType(*duplicateDeviceType)
3401 <<
"` found in gang attribute";
3405 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3406 getGangOperandsDeviceTypeAttr(),
"gang")))
3412 << acc::stringifyDeviceType(*duplicateDeviceType)
3413 <<
"` found in worker attribute";
3414 if (
auto duplicateDeviceType =
3417 << acc::stringifyDeviceType(*duplicateDeviceType)
3418 <<
"` found in workerNumOperandsDeviceType attribute";
3420 getWorkerNumOperandsDeviceTypeAttr(),
3427 << acc::stringifyDeviceType(*duplicateDeviceType)
3428 <<
"` found in vector attribute";
3429 if (
auto duplicateDeviceType =
3432 << acc::stringifyDeviceType(*duplicateDeviceType)
3433 <<
"` found in vectorOperandsDeviceType attribute";
3435 getVectorOperandsDeviceTypeAttr(),
3440 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3441 getTileOperandsDeviceTypeAttr(),
"tile")))
3445 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3449 return emitError() <<
"only one of auto, independent, seq can be present "
3455 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3456 return attr.getValue() == mlir::acc::DeviceType::None;
3458 bool hasDefaultSeq =
3460 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3463 bool hasDefaultIndependent =
3464 getIndependentAttr()
3466 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3469 bool hasDefaultAuto =
3471 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3474 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3476 <<
"at least one of auto, independent, seq must be present";
3481 for (
auto attr : getSeqAttr()) {
3482 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3483 if (hasVector(deviceTypeAttr.getValue()) ||
3484 getVectorValue(deviceTypeAttr.getValue()) ||
3485 hasWorker(deviceTypeAttr.getValue()) ||
3486 getWorkerValue(deviceTypeAttr.getValue()) ||
3487 hasGang(deviceTypeAttr.getValue()) ||
3488 getGangValue(mlir::acc::GangArgType::Num,
3489 deviceTypeAttr.getValue()) ||
3490 getGangValue(mlir::acc::GangArgType::Dim,
3491 deviceTypeAttr.getValue()) ||
3492 getGangValue(mlir::acc::GangArgType::Static,
3493 deviceTypeAttr.getValue()))
3494 return emitError() <<
"gang, worker or vector cannot appear with seq";
3499 mlir::acc::PrivateRecipeOp>(
3500 *
this, getPrivateOperands(),
"private")))
3504 mlir::acc::FirstprivateRecipeOp>(
3505 *
this, getFirstprivateOperands(),
"firstprivate")))
3509 mlir::acc::ReductionRecipeOp>(
3510 *
this, getReductionOperands(),
"reduction")))
3513 if (getCombined().has_value() &&
3514 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3515 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3516 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3517 return emitError(
"unexpected combined constructs attribute");
3521 if (getRegion().empty())
3522 return emitError(
"expected non-empty body.");
3524 if (getUnstructured()) {
3525 if (!isContainerLike())
3527 "unstructured acc.loop must not have induction variables");
3528 }
else if (isContainerLike()) {
3532 uint64_t collapseCount = getCollapseValue().value_or(1);
3533 if (getCollapseAttr()) {
3534 for (
auto collapseEntry : getCollapseAttr()) {
3535 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3536 if (intAttr.getValue().getZExtValue() > collapseCount)
3537 collapseCount = intAttr.getValue().getZExtValue();
3545 bool foundSibling =
false;
3547 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3549 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3551 foundSibling =
true;
3556 expectedParent = op;
3559 if (collapseCount == 0)
3565 return emitError(
"found sibling loops inside container-like acc.loop");
3566 if (collapseCount != 0)
3567 return emitError(
"failed to find enough loop-like operations inside "
3568 "container-like acc.loop");
3574unsigned LoopOp::getNumDataOperands() {
3575 return getReductionOperands().size() + getPrivateOperands().size() +
3576 getFirstprivateOperands().size();
3579Value LoopOp::getDataOperand(
unsigned i) {
3580 unsigned numOptional =
3581 getLowerbound().size() + getUpperbound().size() + getStep().size();
3582 numOptional += getGangOperands().size();
3583 numOptional += getVectorOperands().size();
3584 numOptional += getWorkerNumOperands().size();
3585 numOptional += getTileOperands().size();
3586 numOptional += getCacheOperands().size();
3587 return getOperand(numOptional + i);
3590bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3592bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3596bool LoopOp::hasIndependent() {
3597 return hasIndependent(mlir::acc::DeviceType::None);
3600bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3604bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3606bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3611 return getVectorValue(mlir::acc::DeviceType::None);
3614mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3616 getVectorOperands(), deviceType);
3619bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3621bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3626 return getWorkerValue(mlir::acc::DeviceType::None);
3629mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3631 getWorkerNumOperands(), deviceType);
3634bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3636bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3641 return getTileValues(mlir::acc::DeviceType::None);
3645LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3647 getTileOperandsSegments(), deviceType);
3650std::optional<int64_t> LoopOp::getCollapseValue() {
3651 return getCollapseValue(mlir::acc::DeviceType::None);
3654std::optional<int64_t>
3655LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3656 if (!getCollapseAttr())
3657 return std::nullopt;
3658 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3660 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3661 return intAttr.getValue().getZExtValue();
3663 return std::nullopt;
3666mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3667 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3670mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3671 mlir::acc::DeviceType deviceType) {
3672 if (getGangOperands().empty())
3674 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3675 int32_t nbOperandsBefore = 0;
3676 for (
unsigned i = 0; i < *pos; ++i)
3677 nbOperandsBefore += (*getGangOperandsSegments())[i];
3680 .drop_front(nbOperandsBefore)
3681 .take_front((*getGangOperandsSegments())[*pos]);
3683 int32_t argTypeIdx = nbOperandsBefore;
3684 for (
auto value : values) {
3685 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3686 (*getGangOperandsArgType())[argTypeIdx]);
3687 if (gangArgTypeAttr.getValue() == gangArgType)
3695bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3697bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3702 return {&getRegion()};
3746 if (!regionArgs.empty()) {
3747 p << acc::LoopOp::getControlKeyword() <<
"(";
3748 llvm::interleaveComma(regionArgs, p,
3750 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3751 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3752 <<
" : " << stepType <<
") ";
3759 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3760 effectiveDeviceTypes));
3763void acc::LoopOp::addIndependent(
3765 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3766 context, getIndependentAttr(), effectiveDeviceTypes));
3771 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3772 effectiveDeviceTypes));
3775void acc::LoopOp::setCollapseForDeviceTypes(
3777 llvm::APInt value) {
3781 assert((getCollapseAttr() ==
nullptr) ==
3782 (getCollapseDeviceTypeAttr() ==
nullptr));
3783 assert(value.getBitWidth() == 64);
3785 if (getCollapseAttr()) {
3786 for (
const auto &existing :
3787 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3788 newValues.push_back(std::get<0>(existing));
3789 newDeviceTypes.push_back(std::get<1>(existing));
3793 if (effectiveDeviceTypes.empty()) {
3796 newValues.push_back(
3797 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3798 newDeviceTypes.push_back(
3799 acc::DeviceTypeAttr::get(context, DeviceType::None));
3801 for (DeviceType dt : effectiveDeviceTypes) {
3802 newValues.push_back(
3803 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3804 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3808 setCollapseAttr(ArrayAttr::get(context, newValues));
3809 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3812void acc::LoopOp::setTileForDeviceTypes(
3816 if (getTileOperandsSegments())
3817 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3819 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3820 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3821 getTileOperandsMutable(), segments));
3823 setTileOperandsSegments(segments);
3826void acc::LoopOp::addVectorOperand(
3829 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3830 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3831 newValue, getVectorOperandsMutable()));
3834void acc::LoopOp::addEmptyVector(
3836 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3837 effectiveDeviceTypes));
3840void acc::LoopOp::addWorkerNumOperand(
3843 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3844 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3845 newValue, getWorkerNumOperandsMutable()));
3848void acc::LoopOp::addEmptyWorker(
3850 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3851 effectiveDeviceTypes));
3854void acc::LoopOp::addEmptyGang(
3856 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3857 effectiveDeviceTypes));
3860bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3861 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3862 return attr.getValue() == dt;
3864 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3865 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3868 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3870 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3872 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3878bool acc::LoopOp::hasDefaultGangWorkerVector() {
3879 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3880 hasGang() || getGangValue(GangArgType::Num) ||
3881 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3885acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3886 if (hasSeq(deviceType))
3887 return LoopParMode::loop_seq;
3888 if (hasAuto(deviceType))
3889 return LoopParMode::loop_auto;
3890 if (hasIndependent(deviceType))
3891 return LoopParMode::loop_independent;
3893 return LoopParMode::loop_seq;
3895 return LoopParMode::loop_auto;
3896 assert(hasIndependent() &&
3897 "loop must have default auto, seq, or independent");
3898 return LoopParMode::loop_independent;
3901void acc::LoopOp::addGangOperands(
3906 getGangOperandsSegments())
3907 llvm::copy(*existingSegments, std::back_inserter(segments));
3909 unsigned beforeCount = segments.size();
3911 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3912 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3913 getGangOperandsMutable(), segments));
3915 setGangOperandsSegments(segments);
3922 unsigned numAdded = segments.size() - beforeCount;
3926 if (getGangOperandsArgTypeAttr())
3927 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3929 for (
auto i : llvm::index_range(0u, numAdded)) {
3930 llvm::transform(argTypes, std::back_inserter(gangTypes),
3931 [=](mlir::acc::GangArgType gangTy) {
3932 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3937 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3941void acc::LoopOp::addPrivatization(
MLIRContext *context,
3942 mlir::acc::PrivateOp op,
3943 mlir::acc::PrivateRecipeOp recipe) {
3944 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3945 getPrivateOperandsMutable().append(op.getResult());
3948void acc::LoopOp::addFirstPrivatization(
3949 MLIRContext *context, mlir::acc::FirstprivateOp op,
3950 mlir::acc::FirstprivateRecipeOp recipe) {
3951 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3952 getFirstprivateOperandsMutable().append(op.getResult());
3955void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3956 mlir::acc::ReductionRecipeOp recipe) {
3957 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3958 getReductionOperandsMutable().append(op.getResult());
3965LogicalResult acc::DataOp::verify() {
3970 return emitError(
"at least one operand or the default attribute "
3971 "must appear on the data operation");
3973 for (
mlir::Value operand : getDataClauseOperands())
3974 if (isa<BlockArgument>(operand) ||
3975 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3976 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3977 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3978 operand.getDefiningOp()))
3979 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3988unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3990Value DataOp::getDataOperand(
unsigned i) {
3991 unsigned numOptional = getIfCond() ? 1 : 0;
3993 numOptional += getWaitOperands().size();
3994 return getOperand(numOptional + i);
3997bool acc::DataOp::hasAsyncOnly() {
3998 return hasAsyncOnly(mlir::acc::DeviceType::None);
4001bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4006 return getAsyncValue(mlir::acc::DeviceType::None);
4009mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4014bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
4016bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4021 return getWaitValues(mlir::acc::DeviceType::None);
4025DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4027 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4028 getHasWaitDevnum(), deviceType);
4032 return getWaitDevnum(mlir::acc::DeviceType::None);
4035mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4037 getWaitOperandsSegments(), getHasWaitDevnum(),
4041void acc::DataOp::addAsyncOnly(
4043 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4044 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4047void acc::DataOp::addAsyncOperand(
4050 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4051 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4052 getAsyncOperandsMutable()));
4055void acc::DataOp::addWaitOnly(
MLIRContext *context,
4057 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4058 effectiveDeviceTypes));
4061void acc::DataOp::addWaitOperands(
4066 if (getWaitOperandsSegments())
4067 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4069 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4070 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4071 getWaitOperandsMutable(), segments));
4072 setWaitOperandsSegments(segments);
4075 if (getHasWaitDevnumAttr())
4076 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4079 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4081 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4088LogicalResult acc::ExitDataOp::verify() {
4092 if (getDataClauseOperands().empty())
4093 return emitError(
"at least one operand must be present in dataOperands on "
4094 "the exit data operation");
4098 if (getAsyncOperand() && getAsync())
4099 return emitError(
"async attribute cannot appear with asyncOperand");
4103 if (!getWaitOperands().empty() && getWait())
4104 return emitError(
"wait attribute cannot appear with waitOperands");
4106 if (getWaitDevnum() && getWaitOperands().empty())
4107 return emitError(
"wait_devnum cannot appear without waitOperands");
4112unsigned ExitDataOp::getNumDataOperands() {
4113 return getDataClauseOperands().size();
4116Value ExitDataOp::getDataOperand(
unsigned i) {
4117 unsigned numOptional = getIfCond() ? 1 : 0;
4118 numOptional += getAsyncOperand() ? 1 : 0;
4119 numOptional += getWaitDevnum() ? 1 : 0;
4120 return getOperand(getWaitOperands().size() + numOptional + i);
4125 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4128void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4130 assert(effectiveDeviceTypes.empty());
4131 assert(!getAsyncAttr());
4132 assert(!getAsyncOperand());
4134 setAsyncAttr(mlir::UnitAttr::get(context));
4137void ExitDataOp::addAsyncOperand(
4140 assert(effectiveDeviceTypes.empty());
4141 assert(!getAsyncAttr());
4142 assert(!getAsyncOperand());
4144 getAsyncOperandMutable().append(newValue);
4149 assert(effectiveDeviceTypes.empty());
4150 assert(!getWaitAttr());
4151 assert(getWaitOperands().empty());
4152 assert(!getWaitDevnum());
4154 setWaitAttr(mlir::UnitAttr::get(context));
4157void ExitDataOp::addWaitOperands(
4160 assert(effectiveDeviceTypes.empty());
4161 assert(!getWaitAttr());
4162 assert(getWaitOperands().empty());
4163 assert(!getWaitDevnum());
4168 getWaitDevnumMutable().append(newValues.front());
4169 newValues = newValues.drop_front();
4172 getWaitOperandsMutable().append(newValues);
4179LogicalResult acc::EnterDataOp::verify() {
4183 if (getDataClauseOperands().empty())
4184 return emitError(
"at least one operand must be present in dataOperands on "
4185 "the enter data operation");
4189 if (getAsyncOperand() && getAsync())
4190 return emitError(
"async attribute cannot appear with asyncOperand");
4194 if (!getWaitOperands().empty() && getWait())
4195 return emitError(
"wait attribute cannot appear with waitOperands");
4197 if (getWaitDevnum() && getWaitOperands().empty())
4198 return emitError(
"wait_devnum cannot appear without waitOperands");
4200 for (
mlir::Value operand : getDataClauseOperands())
4201 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4202 operand.getDefiningOp()))
4203 return emitError(
"expect data entry operation as defining op");
4208unsigned EnterDataOp::getNumDataOperands() {
4209 return getDataClauseOperands().size();
4212Value EnterDataOp::getDataOperand(
unsigned i) {
4213 unsigned numOptional = getIfCond() ? 1 : 0;
4214 numOptional += getAsyncOperand() ? 1 : 0;
4215 numOptional += getWaitDevnum() ? 1 : 0;
4216 return getOperand(getWaitOperands().size() + numOptional + i);
4221 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4224void EnterDataOp::addAsyncOnly(
4226 assert(effectiveDeviceTypes.empty());
4227 assert(!getAsyncAttr());
4228 assert(!getAsyncOperand());
4230 setAsyncAttr(mlir::UnitAttr::get(context));
4233void EnterDataOp::addAsyncOperand(
4236 assert(effectiveDeviceTypes.empty());
4237 assert(!getAsyncAttr());
4238 assert(!getAsyncOperand());
4240 getAsyncOperandMutable().append(newValue);
4243void EnterDataOp::addWaitOnly(
MLIRContext *context,
4245 assert(effectiveDeviceTypes.empty());
4246 assert(!getWaitAttr());
4247 assert(getWaitOperands().empty());
4248 assert(!getWaitDevnum());
4250 setWaitAttr(mlir::UnitAttr::get(context));
4253void EnterDataOp::addWaitOperands(
4256 assert(effectiveDeviceTypes.empty());
4257 assert(!getWaitAttr());
4258 assert(getWaitOperands().empty());
4259 assert(!getWaitDevnum());
4264 getWaitDevnumMutable().append(newValues.front());
4265 newValues = newValues.drop_front();
4268 getWaitOperandsMutable().append(newValues);
4275LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4281LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4287LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4294 if (
Value writeVal = op.getWriteOpVal()) {
4303LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4305LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4311AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4312 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4314 return dyn_cast<AtomicReadOp>(getSecondOp());
4317AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4318 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4320 return dyn_cast<AtomicWriteOp>(getSecondOp());
4323AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4324 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4326 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4329LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4335template <
typename Op>
4338 bool requireAtLeastOneOperand =
true) {
4339 if (operands.empty() && requireAtLeastOneOperand)
4342 "at least one operand must appear on the declare operation");
4345 if (isa<BlockArgument>(operand) ||
4346 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4347 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4348 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4349 operand.getDefiningOp()))
4351 "expect valid declare data entry operation or acc.getdeviceptr "
4355 assert(var &&
"declare operands can only be data entry operations which "
4358 std::optional<mlir::acc::DataClause> dataClauseOptional{
4360 assert(dataClauseOptional.has_value() &&
4361 "declare operands can only be data entry operations which must have "
4363 (
void)dataClauseOptional;
4369LogicalResult acc::DeclareEnterOp::verify() {
4377LogicalResult acc::DeclareExitOp::verify() {
4388LogicalResult acc::DeclareOp::verify() {
4397 acc::DeviceType dtype) {
4398 unsigned parallelism = 0;
4399 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4400 parallelism += op.hasWorker(dtype) ? 1 : 0;
4401 parallelism += op.hasVector(dtype) ? 1 : 0;
4402 parallelism += op.hasSeq(dtype) ? 1 : 0;
4406LogicalResult acc::RoutineOp::verify() {
4407 unsigned baseParallelism =
4410 if (baseParallelism > 1)
4411 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4412 "be present at the same time";
4414 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4416 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4417 if (dtype == acc::DeviceType::None)
4421 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4422 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4423 "be present at the same time for device_type `"
4424 << acc::stringifyDeviceType(dtype) <<
"`";
4431 mlir::ArrayAttr &bindIdName,
4432 mlir::ArrayAttr &bindStrName,
4433 mlir::ArrayAttr &deviceIdTypes,
4434 mlir::ArrayAttr &deviceStrTypes) {
4441 mlir::Attribute newAttr;
4442 bool isSymbolRefAttr;
4443 auto parseResult = parser.parseAttribute(newAttr);
4444 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4445 bindIdNameAttrs.push_back(symbolRefAttr);
4446 isSymbolRefAttr = true;
4447 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4448 bindStrNameAttrs.push_back(stringAttr);
4449 isSymbolRefAttr =
false;
4454 if (isSymbolRefAttr) {
4455 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4456 parser.getContext(), mlir::acc::DeviceType::None));
4458 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4459 parser.getContext(), mlir::acc::DeviceType::None));
4462 if (isSymbolRefAttr) {
4463 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4464 parser.parseRSquare())
4467 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4468 parser.parseRSquare())
4476 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4477 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4478 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4479 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4485 std::optional<mlir::ArrayAttr> bindIdName,
4486 std::optional<mlir::ArrayAttr> bindStrName,
4487 std::optional<mlir::ArrayAttr> deviceIdTypes,
4488 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4495 allBindNames.append(bindIdName->begin(), bindIdName->end());
4496 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4501 allBindNames.append(bindStrName->begin(), bindStrName->end());
4502 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4506 if (!allBindNames.empty())
4507 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4508 [&](
const auto &pair) {
4509 p << std::get<0>(pair);
4515 mlir::ArrayAttr &gang,
4516 mlir::ArrayAttr &gangDim,
4517 mlir::ArrayAttr &gangDimDeviceTypes) {
4520 gangDimDeviceTypeAttrs;
4521 bool needCommaBeforeOperands =
false;
4525 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4526 parser.
getContext(), mlir::acc::DeviceType::None));
4527 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4534 if (parser.parseAttribute(gangAttrs.emplace_back()))
4541 needCommaBeforeOperands =
true;
4544 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4548 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4549 parser.parseColon() ||
4550 parser.parseAttribute(gangDimAttrs.emplace_back()))
4552 if (succeeded(parser.parseOptionalLSquare())) {
4553 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4554 parser.parseRSquare())
4557 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4558 parser.getContext(), mlir::acc::DeviceType::None));
4564 if (
failed(parser.parseRParen()))
4567 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4568 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4569 gangDimDeviceTypes =
4570 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4576 std::optional<mlir::ArrayAttr> gang,
4577 std::optional<mlir::ArrayAttr> gangDim,
4578 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4581 gang->size() == 1) {
4582 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4583 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4595 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4596 [&](
const auto &pair) {
4597 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4598 p << std::get<0>(pair);
4606 mlir::ArrayAttr &deviceTypes) {
4610 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4611 parser.
getContext(), mlir::acc::DeviceType::None));
4612 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4619 if (parser.parseAttribute(attributes.emplace_back()))
4627 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4633 std::optional<mlir::ArrayAttr> deviceTypes) {
4636 auto deviceTypeAttr =
4637 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4638 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4647 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4653bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4655bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4659bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4661bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4665bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4667bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4671std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4672RoutineOp::getBindNameValue() {
4673 return getBindNameValue(mlir::acc::DeviceType::None);
4676std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4677RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4680 return std::nullopt;
4683 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4684 auto attr = (*getBindIdName())[*pos];
4685 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4686 assert(symbolRefAttr &&
"expected SymbolRef");
4687 return symbolRefAttr;
4690 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4691 auto attr = (*getBindStrName())[*pos];
4692 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4693 assert(stringAttr &&
"expected String");
4697 return std::nullopt;
4700bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4702bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4706std::optional<int64_t> RoutineOp::getGangDimValue() {
4707 return getGangDimValue(mlir::acc::DeviceType::None);
4710std::optional<int64_t>
4711RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4713 return std::nullopt;
4714 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4715 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4716 return intAttr.getInt();
4718 return std::nullopt;
4723 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4724 effectiveDeviceTypes));
4729 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4730 effectiveDeviceTypes));
4735 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4736 effectiveDeviceTypes));
4741 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4742 effectiveDeviceTypes));
4751 if (getGangDimAttr())
4752 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4753 if (getGangDimDeviceTypeAttr())
4754 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4756 assert(dimValues.size() == deviceTypes.size());
4758 if (effectiveDeviceTypes.empty()) {
4759 dimValues.push_back(
4760 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4761 deviceTypes.push_back(
4762 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4764 for (DeviceType dt : effectiveDeviceTypes) {
4765 dimValues.push_back(
4766 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4767 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4770 assert(dimValues.size() == deviceTypes.size());
4772 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4773 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4776void RoutineOp::addBindStrName(
MLIRContext *context,
4778 mlir::StringAttr val) {
4779 unsigned before = getBindStrNameDeviceTypeAttr()
4780 ? getBindStrNameDeviceTypeAttr().size()
4783 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4784 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4785 unsigned after = getBindStrNameDeviceTypeAttr().size();
4788 if (getBindStrNameAttr())
4789 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4790 for (
unsigned i = 0; i < after - before; ++i)
4791 vals.push_back(val);
4793 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4796void RoutineOp::addBindIDName(
MLIRContext *context,
4798 mlir::SymbolRefAttr val) {
4800 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4802 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4803 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4804 unsigned after = getBindIdNameDeviceTypeAttr().size();
4807 if (getBindIdNameAttr())
4808 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4809 for (
unsigned i = 0; i < after - before; ++i)
4810 vals.push_back(val);
4812 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4819LogicalResult acc::InitOp::verify() {
4823 return emitOpError(
"cannot be nested in a compute operation");
4827void acc::InitOp::addDeviceType(
MLIRContext *context,
4828 mlir::acc::DeviceType deviceType) {
4830 if (getDeviceTypesAttr())
4831 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4833 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4834 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4841LogicalResult acc::ShutdownOp::verify() {
4845 return emitOpError(
"cannot be nested in a compute operation");
4849void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4850 mlir::acc::DeviceType deviceType) {
4852 if (getDeviceTypesAttr())
4853 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4855 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4856 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4863LogicalResult acc::SetOp::verify() {
4867 return emitOpError(
"cannot be nested in a compute operation");
4868 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4869 return emitOpError(
"at least one default_async, device_num, or device_type "
4870 "operand must appear");
4878LogicalResult acc::UpdateOp::verify() {
4880 if (getDataClauseOperands().empty())
4881 return emitError(
"at least one value must be present in dataOperands");
4884 getAsyncOperandsDeviceTypeAttr(),
4889 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4890 getWaitOperandsDeviceTypeAttr(),
"wait")))
4896 for (
mlir::Value operand : getDataClauseOperands())
4897 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4898 operand.getDefiningOp()))
4899 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4905unsigned UpdateOp::getNumDataOperands() {
4906 return getDataClauseOperands().size();
4909Value UpdateOp::getDataOperand(
unsigned i) {
4911 numOptional += getIfCond() ? 1 : 0;
4912 return getOperand(getWaitOperands().size() + numOptional + i);
4917 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4920bool UpdateOp::hasAsyncOnly() {
4921 return hasAsyncOnly(mlir::acc::DeviceType::None);
4924bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4929 return getAsyncValue(mlir::acc::DeviceType::None);
4932mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4942bool UpdateOp::hasWaitOnly() {
4943 return hasWaitOnly(mlir::acc::DeviceType::None);
4946bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4951 return getWaitValues(mlir::acc::DeviceType::None);
4955UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4957 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4958 getHasWaitDevnum(), deviceType);
4962 return getWaitDevnum(mlir::acc::DeviceType::None);
4965mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4967 getWaitOperandsSegments(), getHasWaitDevnum(),
4973 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4974 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4977void UpdateOp::addAsyncOperand(
4980 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4981 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4982 getAsyncOperandsMutable()));
4987 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4988 effectiveDeviceTypes));
4991void UpdateOp::addWaitOperands(
4996 if (getWaitOperandsSegments())
4997 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4999 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5000 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5001 getWaitOperandsMutable(), segments));
5002 setWaitOperandsSegments(segments);
5005 if (getHasWaitDevnumAttr())
5006 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5009 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5011 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5018LogicalResult acc::WaitOp::verify() {
5021 if (getAsyncOperand() && getAsync())
5022 return emitError(
"async attribute cannot appear with asyncOperand");
5024 if (getWaitDevnum() && getWaitOperands().empty())
5025 return emitError(
"wait_devnum cannot appear without waitOperands");
5030#define GET_OP_CLASSES
5031#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5033#define GET_ATTRDEF_CLASSES
5034#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5036#define GET_TYPEDEF_CLASSES
5037#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5048 .Case<ACC_DATA_ENTRY_OPS>(
5049 [&](
auto entry) {
return entry.getVarPtr(); })
5050 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5051 [&](
auto exit) {
return exit.getVarPtr(); })
5069 [&](
auto entry) {
return entry.getVarType(); })
5070 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5071 [&](
auto exit) {
return exit.getVarType(); })
5081 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5082 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5092 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5101 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5111 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5113 dataClause.getBounds().begin(), dataClause.getBounds().end());
5125 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5127 dataClause.getAsyncOperands().begin(),
5128 dataClause.getAsyncOperands().end());
5139 return dataClause.getAsyncOperandsDeviceTypeAttr();
5147 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5154 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5161std::optional<mlir::acc::DataClause>
5166 .Case<ACC_DATA_ENTRY_OPS>(
5167 [&](
auto entry) {
return entry.getDataClause(); })
5175 [&](
auto entry) {
return entry.getImplicit(); })
5184 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5186 return dataOperands;
5194 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5196 return dataOperands;
5203 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether a data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.