25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249 bool isDeviceData(Type pointer, Value var)
const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
256struct LLVMPointerPointerLikeModel
257 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
261 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType)
const {
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
278struct MemrefAddressOfGlobalModel
279 :
public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op)
const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
287struct MemrefGlobalVariableModel
288 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
290 bool isConstant(Operation *op)
const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
295 Region *getInitRegion(Operation *op)
const {
300 bool isDeviceData(Operation *op)
const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
307struct GPULaunchOffloadRegionModel
308 :
public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
311 return cast<gpu::LaunchOp>(op).getBody();
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
333 return mlir::ArrayAttr::get(context, deviceTypes);
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
351 if (newDeviceTypes.empty()) {
352 argCollection.
append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.
append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
364 return mlir::ArrayAttr::get(context, deviceTypes);
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
383void OpenACCDialect::initialize() {
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
448void ParallelOp::getSuccessorRegions(
468void KernelEnvironmentOp::getSuccessorRegions(
488void HostDataOp::getSuccessorRegions(
503 if (getUnstructured()) {
536 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
540 mlir::acc::DeviceType deviceType) {
544 for (
auto attr : *arrayAttr) {
545 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
546 if (deviceTypeAttr.getValue() == deviceType)
554 std::optional<mlir::ArrayAttr> deviceTypes) {
559 llvm::interleaveComma(*deviceTypes, p,
565 mlir::acc::DeviceType deviceType) {
566 unsigned segmentIdx = 0;
567 for (
auto attr : segments) {
568 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
569 if (deviceTypeAttr.getValue() == deviceType)
570 return std::make_optional(segmentIdx);
580 mlir::acc::DeviceType deviceType) {
582 return range.take_front(0);
583 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
584 int32_t nbOperandsBefore = 0;
585 for (
unsigned i = 0; i < *pos; ++i)
586 nbOperandsBefore += (*segments)[i];
587 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
589 return range.take_front(0);
596 std::optional<mlir::ArrayAttr> hasWaitDevnum,
597 mlir::acc::DeviceType deviceType) {
600 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
601 if (hasWaitDevnum->getValue()[*pos])
612 std::optional<mlir::ArrayAttr> hasWaitDevnum,
613 mlir::acc::DeviceType deviceType) {
618 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
619 if (hasWaitDevnum && *hasWaitDevnum) {
620 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
621 if (boolAttr.getValue())
622 return range.drop_front(1);
628template <
typename Op>
630 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
632 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
637 op.hasAsyncOnly(dtype))
639 "asyncOnly attribute cannot appear with asyncOperand");
644 op.hasWaitOnly(dtype))
645 return op.
emitError(
"wait attribute cannot appear with waitOperands");
650template <
typename Op>
653 return op.
emitError(
"must have var operand");
656 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
657 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
658 return op.
emitError(
"var must be mappable or pointer-like");
661 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
662 op.getVarType() == op.getVar().getType())
663 return op.
emitError(
"varType must capture the element type of var");
668template <
typename Op>
670 if (op.getVar().getType() != op.getAccVar().getType())
671 return op.
emitError(
"input and output types must match");
676template <
typename Op>
678 if (op.getModifiers() != acc::DataClauseModifier::none)
679 return op.
emitError(
"no data clause modifiers are allowed");
683template <
typename Op>
686 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
688 "invalid data clause modifiers: " +
689 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
694template <
typename OpT,
typename RecipeOpT>
695static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
700 !std::is_same_v<OpT, acc::ReductionOp>)
703 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
705 return op->emitOpError() <<
"recipe expected for " << operandName;
710 return op->emitOpError()
711 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
712 << operandName <<
" declaration";
733 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
754 if (failed(parser.
parseType(accVarType)))
764 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
776 mlir::TypeAttr &varTypeAttr) {
777 if (failed(parser.
parseType(varPtrType)))
788 varTypeAttr = mlir::TypeAttr::get(varType);
793 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
794 varTypeAttr = mlir::TypeAttr::get(
795 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
797 varTypeAttr = mlir::TypeAttr::get(varPtrType);
804 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
812 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
813 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
815 if (typeToCheckAgainst != varType) {
823 mlir::SymbolRefAttr &recipeAttr) {
830 mlir::SymbolRefAttr recipeAttr) {
837LogicalResult acc::DataBoundsOp::verify() {
838 auto extent = getExtent();
839 auto upperbound = getUpperbound();
840 if (!extent && !upperbound)
841 return emitError(
"expected extent or upperbound.");
848LogicalResult acc::PrivateOp::verify() {
851 "data clause associated with private operation must match its intent");
865LogicalResult acc::FirstprivateOp::verify() {
867 return emitError(
"data clause associated with firstprivate operation must "
874 *
this,
"firstprivate")))
882LogicalResult acc::FirstprivateMapInitialOp::verify() {
884 return emitError(
"data clause associated with firstprivate operation must "
896LogicalResult acc::ReductionOp::verify() {
898 return emitError(
"data clause associated with reduction operation must "
905 *
this,
"reduction")))
913LogicalResult acc::DevicePtrOp::verify() {
915 return emitError(
"data clause associated with deviceptr operation must "
929LogicalResult acc::PresentOp::verify() {
932 "data clause associated with present operation must match its intent");
945LogicalResult acc::CopyinOp::verify() {
947 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
952 "data clause associated with copyin operation must match its intent"
953 " or specify original clause this operation was decomposed from");
959 acc::DataClauseModifier::always |
960 acc::DataClauseModifier::capture)))
965bool acc::CopyinOp::isCopyinReadonly() {
966 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
967 acc::bitEnumContainsAny(getModifiers(),
968 acc::DataClauseModifier::readonly);
974LogicalResult acc::CreateOp::verify() {
981 "data clause associated with create operation must match its intent"
982 " or specify original clause this operation was decomposed from");
990 acc::DataClauseModifier::always |
991 acc::DataClauseModifier::capture)))
996bool acc::CreateOp::isCreateZero() {
998 return getDataClause() == acc::DataClause::acc_create_zero ||
1000 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1006LogicalResult acc::NoCreateOp::verify() {
1008 return emitError(
"data clause associated with no_create operation must "
1009 "match its intent");
1022LogicalResult acc::AttachOp::verify() {
1025 "data clause associated with attach operation must match its intent");
1039LogicalResult acc::DeclareDeviceResidentOp::verify() {
1040 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1041 return emitError(
"data clause associated with device_resident operation "
1042 "must match its intent");
1056LogicalResult acc::DeclareLinkOp::verify() {
1059 "data clause associated with link operation must match its intent");
1072LogicalResult acc::CopyoutOp::verify() {
1079 "data clause associated with copyout operation must match its intent"
1080 " or specify original clause this operation was decomposed from");
1082 return emitError(
"must have both host and device pointers");
1088 acc::DataClauseModifier::always |
1089 acc::DataClauseModifier::capture)))
1094bool acc::CopyoutOp::isCopyoutZero() {
1095 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1096 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1102LogicalResult acc::DeleteOp::verify() {
1111 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1114 "data clause associated with delete operation must match its intent"
1115 " or specify original clause this operation was decomposed from");
1117 return emitError(
"must have device pointer");
1121 acc::DataClauseModifier::readonly |
1122 acc::DataClauseModifier::always |
1123 acc::DataClauseModifier::capture)))
1131LogicalResult acc::DetachOp::verify() {
1136 "data clause associated with detach operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1139 return emitError(
"must have device pointer");
1148LogicalResult acc::UpdateHostOp::verify() {
1153 "data clause associated with host operation must match its intent"
1154 " or specify original clause this operation was decomposed from");
1156 return emitError(
"must have both host and device pointers");
1169LogicalResult acc::UpdateDeviceOp::verify() {
1173 "data clause associated with device operation must match its intent"
1174 " or specify original clause this operation was decomposed from");
1187LogicalResult acc::UseDeviceOp::verify() {
1191 "data clause associated with use_device operation must match its intent"
1192 " or specify original clause this operation was decomposed from");
1205LogicalResult acc::CacheOp::verify() {
1210 "data clause associated with cache operation must match its intent"
1211 " or specify original clause this operation was decomposed from");
1221bool acc::CacheOp::isCacheReadonly() {
1222 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1223 acc::bitEnumContainsAny(getModifiers(),
1224 acc::DataClauseModifier::readonly);
1238 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
1246template <
typename EffectTy>
1251 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1252 effects.emplace_back(EffectTy::get(), &operand[i]);
1256template <
typename EffectTy>
1261 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1265void acc::PrivateOp::getEffects(
1279void acc::FirstprivateOp::getEffects(
1293void acc::FirstprivateMapInitialOp::getEffects(
1303void acc::ReductionOp::getEffects(
1317void acc::DevicePtrOp::getEffects(
1326void acc::PresentOp::getEffects(
1337void acc::CopyinOp::getEffects(
1350void acc::CreateOp::getEffects(
1363void acc::NoCreateOp::getEffects(
1374void acc::AttachOp::getEffects(
1387void acc::GetDevicePtrOp::getEffects(
1396void acc::UpdateDeviceOp::getEffects(
1406void acc::UseDeviceOp::getEffects(
1415void acc::DeclareDeviceResidentOp::getEffects(
1426void acc::DeclareLinkOp::getEffects(
1437void acc::CacheOp::getEffects(
1442void acc::CopyoutOp::getEffects(
1455void acc::DeleteOp::getEffects(
1467void acc::DetachOp::getEffects(
1479void acc::UpdateHostOp::getEffects(
1491template <
typename StructureOp>
1493 unsigned nRegions = 1) {
1496 for (
unsigned i = 0; i < nRegions; ++i)
1499 for (
Region *region : regions)
1507 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1514template <
typename OpTy>
1516 using OpRewritePattern<OpTy>::OpRewritePattern;
1518 LogicalResult matchAndRewrite(OpTy op,
1519 PatternRewriter &rewriter)
const override {
1521 Value ifCond = op.getIfCond();
1525 IntegerAttr constAttr;
1528 if (constAttr.getInt())
1529 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1541 assert(region.
hasOneBlock() &&
"expected single-block region");
1553template <
typename OpTy>
1554struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1555 using OpRewritePattern<OpTy>::OpRewritePattern;
1557 LogicalResult matchAndRewrite(OpTy op,
1558 PatternRewriter &rewriter)
const override {
1560 Value ifCond = op.getIfCond();
1564 IntegerAttr constAttr;
1567 if (constAttr.getInt())
1568 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1578struct RemoveEmptyKernelEnvironment
1580 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1582 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1583 PatternRewriter &rewriter)
const override {
1584 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1595 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1596 for (
auto attr : deviceTypeAttr) {
1597 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1598 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1605 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1606 for (
auto attr : hasDevnumAttr) {
1607 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1608 if (boolAttr.getValue())
1615 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1616 if (segmentsAttr.size() > 1)
1622 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1649 for (
Value bound : bounds) {
1650 argTypes.push_back(bound.getType());
1651 argLocs.push_back(loc);
1658 Value privatizedValue;
1664 if (isa<MappableType>(varType)) {
1665 auto mappableTy = cast<MappableType>(varType);
1666 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1667 privatizedValue = mappableTy.generatePrivateInit(
1668 builder, loc, typedVar, varName, bounds, {}, needsFree);
1669 if (!privatizedValue)
1672 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1673 auto pointerLikeTy = cast<PointerLikeType>(varType);
1675 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1676 blockArgVar, needsFree);
1677 if (!privatizedValue)
1682 acc::YieldOp::create(builder, loc, privatizedValue);
1697 for (
Value bound : bounds) {
1698 copyArgTypes.push_back(bound.getType());
1699 copyArgLocs.push_back(loc);
1706 bool isMappable = isa<MappableType>(varType);
1707 bool isPointerLike = isa<PointerLikeType>(varType);
1710 if (isMappable && !isPointerLike)
1714 if (isPointerLike) {
1715 auto pointerLikeTy = cast<PointerLikeType>(varType);
1720 if (!pointerLikeTy.genCopy(
1727 acc::TerminatorOp::create(builder, loc);
1741 for (
Value bound : bounds) {
1742 destroyArgTypes.push_back(bound.getType());
1743 destroyArgLocs.push_back(loc);
1747 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1751 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1752 if (isa<MappableType>(varType)) {
1753 auto mappableTy = cast<MappableType>(varType);
1754 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1757 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1758 auto pointerLikeTy = cast<PointerLikeType>(varType);
1759 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1763 acc::TerminatorOp::create(builder, loc);
1774 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1776 if (optional && region.
empty())
1780 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1784 return op->
emitOpError() <<
"expects " << regionName
1787 << regionType <<
" type";
1790 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1791 if (yieldOp.getOperands().size() != 1 ||
1792 yieldOp.getOperands().getTypes()[0] != type)
1793 return op->
emitOpError() <<
"expects " << regionName
1795 "yield a value of the "
1796 << regionType <<
" type";
1802LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1804 "privatization",
"init",
getType(),
1808 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1814std::optional<PrivateRecipeOp>
1816 StringRef recipeName,
Type varType,
1819 bool isMappable = isa<MappableType>(varType);
1820 bool isPointerLike = isa<PointerLikeType>(varType);
1823 if (!isMappable && !isPointerLike)
1824 return std::nullopt;
1829 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1832 bool needsFree =
false;
1833 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1834 varName, bounds, needsFree))) {
1836 return std::nullopt;
1843 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1844 Value allocRes = yieldOp.getOperand(0);
1846 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1847 varType, allocRes, bounds))) {
1849 return std::nullopt;
1856std::optional<PrivateRecipeOp>
1858 StringRef recipeName,
1859 FirstprivateRecipeOp firstprivRecipe) {
1862 auto varType = firstprivRecipe.getType();
1863 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1867 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1870 if (!firstprivRecipe.getDestroyRegion().empty()) {
1872 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1882LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1884 "privatization",
"init",
getType(),
1888 if (getCopyRegion().empty())
1889 return emitOpError() <<
"expects non-empty copy region";
1894 return emitOpError() <<
"expects copy region with two arguments of the "
1895 "privatization type";
1897 if (getDestroyRegion().empty())
1901 "privatization",
"destroy",
1908std::optional<FirstprivateRecipeOp>
1910 StringRef recipeName,
Type varType,
1913 bool isMappable = isa<MappableType>(varType);
1914 bool isPointerLike = isa<PointerLikeType>(varType);
1917 if (!isMappable && !isPointerLike)
1918 return std::nullopt;
1923 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1926 bool needsFree =
false;
1927 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1928 varName, bounds, needsFree))) {
1930 return std::nullopt;
1934 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1937 return std::nullopt;
1944 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1945 Value allocRes = yieldOp.getOperand(0);
1947 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1948 varType, allocRes, bounds))) {
1950 return std::nullopt;
1961LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1967 if (getCombinerRegion().empty())
1968 return emitOpError() <<
"expects non-empty combiner region";
1970 Block &reductionBlock = getCombinerRegion().
front();
1974 return emitOpError() <<
"expects combiner region with the first two "
1975 <<
"arguments of the reduction type";
1977 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1978 if (yieldOp.getOperands().size() != 1 ||
1979 yieldOp.getOperands().getTypes()[0] !=
getType())
1980 return emitOpError() <<
"expects combiner region to yield a value "
1981 "of the reduction type";
1992template <
typename Op>
1996 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1997 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1998 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1999 operand.getDefiningOp()))
2001 "expect data entry/exit operation or acc.getdeviceptr "
2006template <
typename OpT,
typename RecipeOpT>
2009 llvm::StringRef operandName) {
2012 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2014 <<
"expected " << operandName <<
" as defining op";
2015 if (!set.insert(operand).second)
2017 << operandName <<
" operand appears more than once";
2022unsigned ParallelOp::getNumDataOperands() {
2023 return getReductionOperands().size() + getPrivateOperands().size() +
2024 getFirstprivateOperands().size() + getDataClauseOperands().size();
2027Value ParallelOp::getDataOperand(
unsigned i) {
2029 numOptional += getNumGangs().size();
2030 numOptional += getNumWorkers().size();
2031 numOptional += getVectorLength().size();
2032 numOptional += getIfCond() ? 1 : 0;
2033 numOptional += getSelfCond() ? 1 : 0;
2034 return getOperand(getWaitOperands().size() + numOptional + i);
2037template <
typename Op>
2040 llvm::StringRef keyword) {
2041 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
2042 return op.
emitOpError() << keyword <<
" operands count must match "
2043 << keyword <<
" device_type count";
2047template <
typename Op>
2050 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2051 std::size_t numOperandsInSegments = 0;
2052 std::size_t nbOfSegments = 0;
2055 for (
auto segCount : segments.
asArrayRef()) {
2056 if (maxInSegment != 0 && segCount > maxInSegment)
2057 return op.
emitOpError() << keyword <<
" expects a maximum of "
2058 << maxInSegment <<
" values per segment";
2059 numOperandsInSegments += segCount;
2064 if ((numOperandsInSegments != operands.size()) ||
2065 (!deviceTypes && !operands.empty()))
2067 << keyword <<
" operand count does not match count in segments";
2068 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2070 << keyword <<
" segment count does not match device_type count";
2074LogicalResult acc::ParallelOp::verify() {
2076 mlir::acc::PrivateRecipeOp>(
2077 *
this, getPrivateOperands(),
"private")))
2080 mlir::acc::FirstprivateRecipeOp>(
2081 *
this, getFirstprivateOperands(),
"firstprivate")))
2084 mlir::acc::ReductionRecipeOp>(
2085 *
this, getReductionOperands(),
"reduction")))
2089 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2090 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2094 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2095 getWaitOperandsDeviceTypeAttr(),
"wait")))
2099 getNumWorkersDeviceTypeAttr(),
2104 getVectorLengthDeviceTypeAttr(),
2109 getAsyncOperandsDeviceTypeAttr(),
2122 mlir::acc::DeviceType deviceType) {
2125 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2130bool acc::ParallelOp::hasAsyncOnly() {
2131 return hasAsyncOnly(mlir::acc::DeviceType::None);
2134bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2139 return getAsyncValue(mlir::acc::DeviceType::None);
2142mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2147mlir::Value acc::ParallelOp::getNumWorkersValue() {
2148 return getNumWorkersValue(mlir::acc::DeviceType::None);
2152acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2157mlir::Value acc::ParallelOp::getVectorLengthValue() {
2158 return getVectorLengthValue(mlir::acc::DeviceType::None);
2162acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2164 getVectorLength(), deviceType);
2168 return getNumGangsValues(mlir::acc::DeviceType::None);
2172ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2174 getNumGangsSegments(), deviceType);
2177bool acc::ParallelOp::hasWaitOnly() {
2178 return hasWaitOnly(mlir::acc::DeviceType::None);
2181bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2186 return getWaitValues(mlir::acc::DeviceType::None);
2190ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2192 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2193 getHasWaitDevnum(), deviceType);
2197 return getWaitDevnum(mlir::acc::DeviceType::None);
2200mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2202 getWaitOperandsSegments(), getHasWaitDevnum(),
2217 odsBuilder, odsState, asyncOperands,
nullptr,
2218 nullptr, waitOperands,
nullptr,
2220 nullptr, numGangs,
nullptr,
2221 nullptr, numWorkers,
2222 nullptr, vectorLength,
2223 nullptr, ifCond, selfCond,
2224 nullptr, reductionOperands, gangPrivateOperands,
2225 gangFirstPrivateOperands, dataClauseOperands,
2229void acc::ParallelOp::addNumWorkersOperand(
2232 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2233 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2234 getNumWorkersMutable()));
2236void acc::ParallelOp::addVectorLengthOperand(
2239 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2240 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2241 getVectorLengthMutable()));
2244void acc::ParallelOp::addAsyncOnly(
2246 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2247 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2250void acc::ParallelOp::addAsyncOperand(
2253 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2254 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2255 getAsyncOperandsMutable()));
2258void acc::ParallelOp::addNumGangsOperands(
2262 if (getNumGangsSegments())
2263 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2265 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2266 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2267 getNumGangsMutable(), segments));
2269 setNumGangsSegments(segments);
2271void acc::ParallelOp::addWaitOnly(
2273 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2274 effectiveDeviceTypes));
2276void acc::ParallelOp::addWaitOperands(
2281 if (getWaitOperandsSegments())
2282 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2284 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2285 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2286 getWaitOperandsMutable(), segments));
2287 setWaitOperandsSegments(segments);
2290 if (getHasWaitDevnumAttr())
2291 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2294 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2296 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2299void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2300 mlir::acc::PrivateOp op,
2301 mlir::acc::PrivateRecipeOp recipe) {
2302 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2303 getPrivateOperandsMutable().append(op.getResult());
2306void acc::ParallelOp::addFirstPrivatization(
2307 MLIRContext *context, mlir::acc::FirstprivateOp op,
2308 mlir::acc::FirstprivateRecipeOp recipe) {
2309 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2310 getFirstprivateOperandsMutable().append(op.getResult());
2313void acc::ParallelOp::addReduction(
MLIRContext *context,
2314 mlir::acc::ReductionOp op,
2315 mlir::acc::ReductionRecipeOp recipe) {
2316 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2317 getReductionOperandsMutable().append(op.getResult());
2332 int32_t crtOperandsSize = operands.size();
2335 if (parser.parseOperand(operands.emplace_back()) ||
2336 parser.parseColonType(types.emplace_back()))
2341 seg.push_back(operands.size() - crtOperandsSize);
2351 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2352 parser.
getContext(), mlir::acc::DeviceType::None));
2358 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2365 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2366 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2367 p <<
" [" << attr <<
"]";
2372 std::optional<mlir::ArrayAttr> deviceTypes,
2373 std::optional<mlir::DenseI32ArrayAttr> segments) {
2375 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2377 llvm::interleaveComma(
2378 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2379 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2399 int32_t crtOperandsSize = operands.size();
2403 if (parser.parseOperand(operands.emplace_back()) ||
2404 parser.parseColonType(types.emplace_back()))
2410 seg.push_back(operands.size() - crtOperandsSize);
2420 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2421 parser.
getContext(), mlir::acc::DeviceType::None));
2427 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2436 std::optional<mlir::DenseI32ArrayAttr> segments) {
2438 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2440 llvm::interleaveComma(
2441 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2442 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2455 mlir::ArrayAttr &keywordOnly) {
2459 bool needCommaBeforeOperands =
false;
2463 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2464 parser.
getContext(), mlir::acc::DeviceType::None));
2465 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2472 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2479 needCommaBeforeOperands =
true;
2482 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2489 int32_t crtOperandsSize = operands.size();
2501 if (parser.parseOperand(operands.emplace_back()) ||
2502 parser.parseColonType(types.emplace_back()))
2508 seg.push_back(operands.size() - crtOperandsSize);
2518 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2519 parser.
getContext(), mlir::acc::DeviceType::None));
2526 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2527 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2529 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2537 if (attrs->size() != 1)
2539 if (
auto deviceTypeAttr =
2540 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2541 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2547 std::optional<mlir::ArrayAttr> deviceTypes,
2548 std::optional<mlir::DenseI32ArrayAttr> segments,
2549 std::optional<mlir::ArrayAttr> hasDevNum,
2550 std::optional<mlir::ArrayAttr> keywordOnly) {
2563 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2565 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2566 if (boolAttr && boolAttr.getValue())
2568 llvm::interleaveComma(
2569 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2570 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2587 if (parser.parseOperand(operands.emplace_back()) ||
2588 parser.parseColonType(types.emplace_back()))
2590 if (succeeded(parser.parseOptionalLSquare())) {
2591 if (parser.parseAttribute(attributes.emplace_back()) ||
2592 parser.parseRSquare())
2595 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2596 parser.getContext(), mlir::acc::DeviceType::None));
2603 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2610 std::optional<mlir::ArrayAttr> deviceTypes) {
2613 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2614 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2623 mlir::ArrayAttr &keywordOnlyDeviceType) {
2626 bool needCommaBeforeOperands =
false;
2630 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2631 parser.
getContext(), mlir::acc::DeviceType::None));
2632 keywordOnlyDeviceType =
2633 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2641 if (parser.parseAttribute(
2642 keywordOnlyDeviceTypeAttributes.emplace_back()))
2649 needCommaBeforeOperands =
true;
2652 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2657 if (parser.parseOperand(operands.emplace_back()) ||
2658 parser.parseColonType(types.emplace_back()))
2660 if (succeeded(parser.parseOptionalLSquare())) {
2661 if (parser.parseAttribute(attributes.emplace_back()) ||
2662 parser.parseRSquare())
2665 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2666 parser.getContext(), mlir::acc::DeviceType::None));
2672 if (
failed(parser.parseRParen()))
2677 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2684 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2686 if (operands.begin() == operands.end() &&
2702 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2703 mlir::Type &operandType, mlir::UnitAttr &attr) {
2706 attr = mlir::UnitAttr::get(parser.
getContext());
2716 if (failed(parser.
parseType(operandType)))
2726 std::optional<mlir::Value> operand,
2728 mlir::UnitAttr attr) {
2745 attr = mlir::UnitAttr::get(parser.
getContext());
2750 if (parser.parseOperand(operands.emplace_back()))
2758 if (parser.parseType(types.emplace_back()))
2773 mlir::UnitAttr attr) {
2778 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2780 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2786 mlir::acc::CombinedConstructsTypeAttr &attr) {
2788 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2789 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2791 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2792 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2794 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2795 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2798 "expected compute construct name");
2806 mlir::acc::CombinedConstructsTypeAttr attr) {
2808 switch (attr.getValue()) {
2809 case mlir::acc::CombinedConstructsType::KernelsLoop:
2812 case mlir::acc::CombinedConstructsType::ParallelLoop:
2815 case mlir::acc::CombinedConstructsType::SerialLoop:
2826unsigned SerialOp::getNumDataOperands() {
2827 return getReductionOperands().size() + getPrivateOperands().size() +
2828 getFirstprivateOperands().size() + getDataClauseOperands().size();
2831Value SerialOp::getDataOperand(
unsigned i) {
2833 numOptional += getIfCond() ? 1 : 0;
2834 numOptional += getSelfCond() ? 1 : 0;
2835 return getOperand(getWaitOperands().size() + numOptional + i);
2838bool acc::SerialOp::hasAsyncOnly() {
2839 return hasAsyncOnly(mlir::acc::DeviceType::None);
2842bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2847 return getAsyncValue(mlir::acc::DeviceType::None);
2850mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2855bool acc::SerialOp::hasWaitOnly() {
2856 return hasWaitOnly(mlir::acc::DeviceType::None);
2859bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2864 return getWaitValues(mlir::acc::DeviceType::None);
2868SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2870 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2871 getHasWaitDevnum(), deviceType);
2875 return getWaitDevnum(mlir::acc::DeviceType::None);
2878mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2880 getWaitOperandsSegments(), getHasWaitDevnum(),
2884LogicalResult acc::SerialOp::verify() {
2886 mlir::acc::PrivateRecipeOp>(
2887 *
this, getPrivateOperands(),
"private")))
2890 mlir::acc::FirstprivateRecipeOp>(
2891 *
this, getFirstprivateOperands(),
"firstprivate")))
2894 mlir::acc::ReductionRecipeOp>(
2895 *
this, getReductionOperands(),
"reduction")))
2899 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2900 getWaitOperandsDeviceTypeAttr(),
"wait")))
2904 getAsyncOperandsDeviceTypeAttr(),
2914void acc::SerialOp::addAsyncOnly(
2916 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2917 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2920void acc::SerialOp::addAsyncOperand(
2923 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2924 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2925 getAsyncOperandsMutable()));
2928void acc::SerialOp::addWaitOnly(
2930 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2931 effectiveDeviceTypes));
2933void acc::SerialOp::addWaitOperands(
2938 if (getWaitOperandsSegments())
2939 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2941 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2942 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2943 getWaitOperandsMutable(), segments));
2944 setWaitOperandsSegments(segments);
2947 if (getHasWaitDevnumAttr())
2948 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2951 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2953 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2956void acc::SerialOp::addPrivatization(
MLIRContext *context,
2957 mlir::acc::PrivateOp op,
2958 mlir::acc::PrivateRecipeOp recipe) {
2959 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2960 getPrivateOperandsMutable().append(op.getResult());
2963void acc::SerialOp::addFirstPrivatization(
2964 MLIRContext *context, mlir::acc::FirstprivateOp op,
2965 mlir::acc::FirstprivateRecipeOp recipe) {
2966 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2967 getFirstprivateOperandsMutable().append(op.getResult());
2970void acc::SerialOp::addReduction(
MLIRContext *context,
2971 mlir::acc::ReductionOp op,
2972 mlir::acc::ReductionRecipeOp recipe) {
2973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2974 getReductionOperandsMutable().append(op.getResult());
2981unsigned KernelsOp::getNumDataOperands() {
2982 return getDataClauseOperands().size();
2985Value KernelsOp::getDataOperand(
unsigned i) {
2987 numOptional += getWaitOperands().size();
2988 numOptional += getNumGangs().size();
2989 numOptional += getNumWorkers().size();
2990 numOptional += getVectorLength().size();
2991 numOptional += getIfCond() ? 1 : 0;
2992 numOptional += getSelfCond() ? 1 : 0;
2993 return getOperand(numOptional + i);
2996bool acc::KernelsOp::hasAsyncOnly() {
2997 return hasAsyncOnly(mlir::acc::DeviceType::None);
3000bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3005 return getAsyncValue(mlir::acc::DeviceType::None);
3008mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3014 return getNumWorkersValue(mlir::acc::DeviceType::None);
3018acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3023mlir::Value acc::KernelsOp::getVectorLengthValue() {
3024 return getVectorLengthValue(mlir::acc::DeviceType::None);
3028acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3030 getVectorLength(), deviceType);
3034 return getNumGangsValues(mlir::acc::DeviceType::None);
3038KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3040 getNumGangsSegments(), deviceType);
3043bool acc::KernelsOp::hasWaitOnly() {
3044 return hasWaitOnly(mlir::acc::DeviceType::None);
3047bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3052 return getWaitValues(mlir::acc::DeviceType::None);
3056KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3058 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3059 getHasWaitDevnum(), deviceType);
3063 return getWaitDevnum(mlir::acc::DeviceType::None);
3066mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3068 getWaitOperandsSegments(), getHasWaitDevnum(),
3072LogicalResult acc::KernelsOp::verify() {
3074 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
3075 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
3079 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3080 getWaitOperandsDeviceTypeAttr(),
"wait")))
3084 getNumWorkersDeviceTypeAttr(),
3089 getVectorLengthDeviceTypeAttr(),
3094 getAsyncOperandsDeviceTypeAttr(),
3104void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3105 mlir::acc::PrivateOp op,
3106 mlir::acc::PrivateRecipeOp recipe) {
3107 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3108 getPrivateOperandsMutable().append(op.getResult());
3111void acc::KernelsOp::addFirstPrivatization(
3112 MLIRContext *context, mlir::acc::FirstprivateOp op,
3113 mlir::acc::FirstprivateRecipeOp recipe) {
3114 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3115 getFirstprivateOperandsMutable().append(op.getResult());
3118void acc::KernelsOp::addReduction(
MLIRContext *context,
3119 mlir::acc::ReductionOp op,
3120 mlir::acc::ReductionRecipeOp recipe) {
3121 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3122 getReductionOperandsMutable().append(op.getResult());
3125void acc::KernelsOp::addNumWorkersOperand(
3128 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3129 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3130 getNumWorkersMutable()));
3133void acc::KernelsOp::addVectorLengthOperand(
3136 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3137 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3138 getVectorLengthMutable()));
3140void acc::KernelsOp::addAsyncOnly(
3142 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3143 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3146void acc::KernelsOp::addAsyncOperand(
3149 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3150 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3151 getAsyncOperandsMutable()));
3154void acc::KernelsOp::addNumGangsOperands(
3158 if (getNumGangsSegmentsAttr())
3159 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3161 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3162 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3163 getNumGangsMutable(), segments));
3165 setNumGangsSegments(segments);
3168void acc::KernelsOp::addWaitOnly(
3170 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3171 effectiveDeviceTypes));
3173void acc::KernelsOp::addWaitOperands(
3178 if (getWaitOperandsSegments())
3179 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3181 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3182 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3183 getWaitOperandsMutable(), segments));
3184 setWaitOperandsSegments(segments);
3187 if (getHasWaitDevnumAttr())
3188 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3191 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3193 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3200LogicalResult acc::HostDataOp::verify() {
3201 if (getDataClauseOperands().empty())
3202 return emitError(
"at least one operand must appear on the host_data "
3206 for (
mlir::Value operand : getDataClauseOperands()) {
3208 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3210 return emitError(
"expect data entry operation as defining op");
3213 if (!seenVars.insert(useDeviceOp.getVar()).second)
3214 return emitError(
"duplicate use_device variable");
3221 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3228void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
3230 results.
add<RemoveEmptyKernelEnvironment>(context);
3242 bool &needCommaBetweenValues,
bool &newValue) {
3249 attributes.push_back(gangArgType);
3250 needCommaBetweenValues =
true;
3261 mlir::ArrayAttr &gangOnlyDeviceType) {
3266 bool needCommaBetweenValues =
false;
3267 bool needCommaBeforeOperands =
false;
3271 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3272 parser.
getContext(), mlir::acc::DeviceType::None));
3273 gangOnlyDeviceType =
3274 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3282 if (parser.parseAttribute(
3283 gangOnlyDeviceTypeAttributes.emplace_back()))
3290 needCommaBeforeOperands =
true;
3293 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3294 mlir::acc::GangArgType::Num);
3295 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3296 mlir::acc::GangArgType::Dim);
3297 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3298 parser.
getContext(), mlir::acc::GangArgType::Static);
3301 if (needCommaBeforeOperands) {
3302 needCommaBeforeOperands =
false;
3309 int32_t crtOperandsSize = gangOperands.size();
3311 bool newValue =
false;
3312 bool needValue =
false;
3313 if (needCommaBetweenValues) {
3321 gangOperands, gangOperandsType,
3322 gangArgTypeAttributes, argNum,
3323 needCommaBetweenValues, newValue)))
3326 gangOperands, gangOperandsType,
3327 gangArgTypeAttributes, argDim,
3328 needCommaBetweenValues, newValue)))
3330 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3331 gangOperands, gangOperandsType,
3332 gangArgTypeAttributes, argStatic,
3333 needCommaBetweenValues, newValue)))
3336 if (!newValue && needValue) {
3338 "new value expected after comma");
3346 if (gangOperands.empty())
3349 "expect at least one of num, dim or static values");
3355 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3359 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3360 parser.
getContext(), mlir::acc::DeviceType::None));
3363 seg.push_back(gangOperands.size() - crtOperandsSize);
3371 gangArgTypeAttributes.end());
3372 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3373 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3376 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3377 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3385 std::optional<mlir::ArrayAttr> gangArgTypes,
3386 std::optional<mlir::ArrayAttr> deviceTypes,
3387 std::optional<mlir::DenseI32ArrayAttr> segments,
3388 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3390 if (operands.begin() == operands.end() &&
3405 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3407 llvm::interleaveComma(
3408 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3409 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3410 (*gangArgTypes)[opIdx]);
3411 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3412 p << LoopOp::getGangNumKeyword();
3413 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3414 p << LoopOp::getGangDimKeyword();
3415 else if (gangArgTypeAttr.getValue() ==
3416 mlir::acc::GangArgType::Static)
3417 p << LoopOp::getGangStaticKeyword();
3418 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3429 std::optional<mlir::ArrayAttr> segments,
3430 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3433 for (
auto attr : *segments) {
3434 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3435 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3443static std::optional<mlir::acc::DeviceType>
3445 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3447 return std::nullopt;
3448 for (
auto attr : deviceTypes) {
3449 auto deviceTypeAttr =
3450 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3451 if (!deviceTypeAttr)
3452 return mlir::acc::DeviceType::None;
3453 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3454 return deviceTypeAttr.getValue();
3456 return std::nullopt;
3459LogicalResult acc::LoopOp::verify() {
3460 if (getUpperbound().size() != getStep().size())
3461 return emitError() <<
"number of upperbounds expected to be the same as "
3464 if (getUpperbound().size() != getLowerbound().size())
3465 return emitError() <<
"number of upperbounds expected to be the same as "
3466 "number of lowerbounds";
3468 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3469 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3470 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3471 <<
" as upperbound size";
3474 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3475 return emitOpError() <<
"collapse device_type attr must be define when"
3476 <<
" collapse attr is present";
3478 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3479 getCollapseAttr().getValue().size() !=
3480 getCollapseDeviceTypeAttr().getValue().size())
3481 return emitOpError() <<
"collapse attribute count must match collapse"
3482 <<
" device_type count";
3483 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3485 << acc::stringifyDeviceType(*duplicateDeviceType)
3486 <<
"` found in collapseDeviceType attribute";
3489 if (!getGangOperands().empty()) {
3490 if (!getGangOperandsArgType())
3491 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3492 <<
" when gang operands are present";
3494 if (getGangOperands().size() !=
3495 getGangOperandsArgTypeAttr().getValue().size())
3496 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3497 <<
" gangOperands count";
3499 if (getGangAttr()) {
3502 << acc::stringifyDeviceType(*duplicateDeviceType)
3503 <<
"` found in gang attribute";
3507 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3508 getGangOperandsDeviceTypeAttr(),
"gang")))
3514 << acc::stringifyDeviceType(*duplicateDeviceType)
3515 <<
"` found in worker attribute";
3516 if (
auto duplicateDeviceType =
3519 << acc::stringifyDeviceType(*duplicateDeviceType)
3520 <<
"` found in workerNumOperandsDeviceType attribute";
3522 getWorkerNumOperandsDeviceTypeAttr(),
3529 << acc::stringifyDeviceType(*duplicateDeviceType)
3530 <<
"` found in vector attribute";
3531 if (
auto duplicateDeviceType =
3534 << acc::stringifyDeviceType(*duplicateDeviceType)
3535 <<
"` found in vectorOperandsDeviceType attribute";
3537 getVectorOperandsDeviceTypeAttr(),
3542 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3543 getTileOperandsDeviceTypeAttr(),
"tile")))
3547 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3551 return emitError() <<
"only one of auto, independent, seq can be present "
3557 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3558 return attr.getValue() == mlir::acc::DeviceType::None;
3560 bool hasDefaultSeq =
3562 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3565 bool hasDefaultIndependent =
3566 getIndependentAttr()
3568 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3571 bool hasDefaultAuto =
3573 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3576 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3578 <<
"at least one of auto, independent, seq must be present";
3583 for (
auto attr : getSeqAttr()) {
3584 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3585 if (hasVector(deviceTypeAttr.getValue()) ||
3586 getVectorValue(deviceTypeAttr.getValue()) ||
3587 hasWorker(deviceTypeAttr.getValue()) ||
3588 getWorkerValue(deviceTypeAttr.getValue()) ||
3589 hasGang(deviceTypeAttr.getValue()) ||
3590 getGangValue(mlir::acc::GangArgType::Num,
3591 deviceTypeAttr.getValue()) ||
3592 getGangValue(mlir::acc::GangArgType::Dim,
3593 deviceTypeAttr.getValue()) ||
3594 getGangValue(mlir::acc::GangArgType::Static,
3595 deviceTypeAttr.getValue()))
3596 return emitError() <<
"gang, worker or vector cannot appear with seq";
3601 mlir::acc::PrivateRecipeOp>(
3602 *
this, getPrivateOperands(),
"private")))
3606 mlir::acc::FirstprivateRecipeOp>(
3607 *
this, getFirstprivateOperands(),
"firstprivate")))
3611 mlir::acc::ReductionRecipeOp>(
3612 *
this, getReductionOperands(),
"reduction")))
3615 if (getCombined().has_value() &&
3616 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3617 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3618 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3619 return emitError(
"unexpected combined constructs attribute");
3623 if (getRegion().empty())
3624 return emitError(
"expected non-empty body.");
3626 if (getUnstructured()) {
3627 if (!isContainerLike())
3629 "unstructured acc.loop must not have induction variables");
3630 }
else if (isContainerLike()) {
3634 uint64_t collapseCount = getCollapseValue().value_or(1);
3635 if (getCollapseAttr()) {
3636 for (
auto collapseEntry : getCollapseAttr()) {
3637 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3638 if (intAttr.getValue().getZExtValue() > collapseCount)
3639 collapseCount = intAttr.getValue().getZExtValue();
3647 bool foundSibling =
false;
3649 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3651 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3653 foundSibling =
true;
3658 expectedParent = op;
3661 if (collapseCount == 0)
3667 return emitError(
"found sibling loops inside container-like acc.loop");
3668 if (collapseCount != 0)
3669 return emitError(
"failed to find enough loop-like operations inside "
3670 "container-like acc.loop");
3676unsigned LoopOp::getNumDataOperands() {
3677 return getReductionOperands().size() + getPrivateOperands().size() +
3678 getFirstprivateOperands().size();
3681Value LoopOp::getDataOperand(
unsigned i) {
3682 unsigned numOptional =
3683 getLowerbound().size() + getUpperbound().size() + getStep().size();
3684 numOptional += getGangOperands().size();
3685 numOptional += getVectorOperands().size();
3686 numOptional += getWorkerNumOperands().size();
3687 numOptional += getTileOperands().size();
3688 numOptional += getCacheOperands().size();
3689 return getOperand(numOptional + i);
3692bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3694bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3698bool LoopOp::hasIndependent() {
3699 return hasIndependent(mlir::acc::DeviceType::None);
3702bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3706bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3708bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3713 return getVectorValue(mlir::acc::DeviceType::None);
3716mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3718 getVectorOperands(), deviceType);
3721bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3723bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3728 return getWorkerValue(mlir::acc::DeviceType::None);
3731mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3733 getWorkerNumOperands(), deviceType);
3736bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3738bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3743 return getTileValues(mlir::acc::DeviceType::None);
3747LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3749 getTileOperandsSegments(), deviceType);
3752std::optional<int64_t> LoopOp::getCollapseValue() {
3753 return getCollapseValue(mlir::acc::DeviceType::None);
3756std::optional<int64_t>
3757LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3758 if (!getCollapseAttr())
3759 return std::nullopt;
3760 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3762 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3763 return intAttr.getValue().getZExtValue();
3765 return std::nullopt;
3768mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3769 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3772mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3773 mlir::acc::DeviceType deviceType) {
3774 if (getGangOperands().empty())
3776 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3777 int32_t nbOperandsBefore = 0;
3778 for (
unsigned i = 0; i < *pos; ++i)
3779 nbOperandsBefore += (*getGangOperandsSegments())[i];
3782 .drop_front(nbOperandsBefore)
3783 .take_front((*getGangOperandsSegments())[*pos]);
3785 int32_t argTypeIdx = nbOperandsBefore;
3786 for (
auto value : values) {
3787 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3788 (*getGangOperandsArgType())[argTypeIdx]);
3789 if (gangArgTypeAttr.getValue() == gangArgType)
3797bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3799bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3804 return {&getRegion()};
3848 if (!regionArgs.empty()) {
3849 p << acc::LoopOp::getControlKeyword() <<
"(";
3850 llvm::interleaveComma(regionArgs, p,
3852 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3853 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3854 <<
" : " << stepType <<
") ";
3861 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3862 effectiveDeviceTypes));
3865void acc::LoopOp::addIndependent(
3867 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3868 context, getIndependentAttr(), effectiveDeviceTypes));
3873 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3874 effectiveDeviceTypes));
3877void acc::LoopOp::setCollapseForDeviceTypes(
3879 llvm::APInt value) {
3883 assert((getCollapseAttr() ==
nullptr) ==
3884 (getCollapseDeviceTypeAttr() ==
nullptr));
3885 assert(value.getBitWidth() == 64);
3887 if (getCollapseAttr()) {
3888 for (
const auto &existing :
3889 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3890 newValues.push_back(std::get<0>(existing));
3891 newDeviceTypes.push_back(std::get<1>(existing));
3895 if (effectiveDeviceTypes.empty()) {
3898 newValues.push_back(
3899 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3900 newDeviceTypes.push_back(
3901 acc::DeviceTypeAttr::get(context, DeviceType::None));
3903 for (DeviceType dt : effectiveDeviceTypes) {
3904 newValues.push_back(
3905 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3906 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3910 setCollapseAttr(ArrayAttr::get(context, newValues));
3911 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3914void acc::LoopOp::setTileForDeviceTypes(
3918 if (getTileOperandsSegments())
3919 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3921 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3922 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3923 getTileOperandsMutable(), segments));
3925 setTileOperandsSegments(segments);
3928void acc::LoopOp::addVectorOperand(
3931 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3932 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3933 newValue, getVectorOperandsMutable()));
3936void acc::LoopOp::addEmptyVector(
3938 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3939 effectiveDeviceTypes));
3942void acc::LoopOp::addWorkerNumOperand(
3945 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3946 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3947 newValue, getWorkerNumOperandsMutable()));
3950void acc::LoopOp::addEmptyWorker(
3952 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3953 effectiveDeviceTypes));
3956void acc::LoopOp::addEmptyGang(
3958 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3959 effectiveDeviceTypes));
3962bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3963 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3964 return attr.getValue() == dt;
3966 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3967 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3970 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3972 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3974 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3980bool acc::LoopOp::hasDefaultGangWorkerVector() {
3981 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3982 hasGang() || getGangValue(GangArgType::Num) ||
3983 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3987acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3988 if (hasSeq(deviceType))
3989 return LoopParMode::loop_seq;
3990 if (hasAuto(deviceType))
3991 return LoopParMode::loop_auto;
3992 if (hasIndependent(deviceType))
3993 return LoopParMode::loop_independent;
3995 return LoopParMode::loop_seq;
3997 return LoopParMode::loop_auto;
3998 assert(hasIndependent() &&
3999 "loop must have default auto, seq, or independent");
4000 return LoopParMode::loop_independent;
4003void acc::LoopOp::addGangOperands(
4008 getGangOperandsSegments())
4009 llvm::copy(*existingSegments, std::back_inserter(segments));
4011 unsigned beforeCount = segments.size();
4013 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4014 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
4015 getGangOperandsMutable(), segments));
4017 setGangOperandsSegments(segments);
4024 unsigned numAdded = segments.size() - beforeCount;
4028 if (getGangOperandsArgTypeAttr())
4029 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
4031 for (
auto i : llvm::index_range(0u, numAdded)) {
4032 llvm::transform(argTypes, std::back_inserter(gangTypes),
4033 [=](mlir::acc::GangArgType gangTy) {
4034 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
4039 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
4043void acc::LoopOp::addPrivatization(
MLIRContext *context,
4044 mlir::acc::PrivateOp op,
4045 mlir::acc::PrivateRecipeOp recipe) {
4046 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4047 getPrivateOperandsMutable().append(op.getResult());
4050void acc::LoopOp::addFirstPrivatization(
4051 MLIRContext *context, mlir::acc::FirstprivateOp op,
4052 mlir::acc::FirstprivateRecipeOp recipe) {
4053 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4054 getFirstprivateOperandsMutable().append(op.getResult());
4057void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
4058 mlir::acc::ReductionRecipeOp recipe) {
4059 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4060 getReductionOperandsMutable().append(op.getResult());
4067LogicalResult acc::DataOp::verify() {
4072 return emitError(
"at least one operand or the default attribute "
4073 "must appear on the data operation");
4075 for (
mlir::Value operand : getDataClauseOperands())
4076 if (isa<BlockArgument>(operand) ||
4077 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4078 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
4079 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
4080 operand.getDefiningOp()))
4081 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
// Number of operands attached to this op through data clauses, i.e. the size
// of the getDataClauseOperands() range.
4090unsigned DataOp::getNumDataOperands() {
 return getDataClauseOperands().size(); }
4092Value DataOp::getDataOperand(
unsigned i) {
4093 unsigned numOptional = getIfCond() ? 1 : 0;
4095 numOptional += getWaitOperands().size();
4096 return getOperand(numOptional + i);
4099bool acc::DataOp::hasAsyncOnly() {
4100 return hasAsyncOnly(mlir::acc::DeviceType::None);
4103bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4108 return getAsyncValue(mlir::acc::DeviceType::None);
4111mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
// Overload without an explicit device type; equivalent to
// hasWaitOnly(mlir::acc::DeviceType::None).
4116bool DataOp::hasWaitOnly() {
 return hasWaitOnly(mlir::acc::DeviceType::None); }
4118bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4123 return getWaitValues(mlir::acc::DeviceType::None);
4127DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4129 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4130 getHasWaitDevnum(), deviceType);
4134 return getWaitDevnum(mlir::acc::DeviceType::None);
4137mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4139 getWaitOperandsSegments(), getHasWaitDevnum(),
4143void acc::DataOp::addAsyncOnly(
4145 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4146 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4149void acc::DataOp::addAsyncOperand(
4152 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4153 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4154 getAsyncOperandsMutable()));
4157void acc::DataOp::addWaitOnly(
MLIRContext *context,
4159 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4160 effectiveDeviceTypes));
4163void acc::DataOp::addWaitOperands(
4168 if (getWaitOperandsSegments())
4169 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4171 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4172 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4173 getWaitOperandsMutable(), segments));
4174 setWaitOperandsSegments(segments);
4177 if (getHasWaitDevnumAttr())
4178 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4181 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4183 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4190LogicalResult acc::ExitDataOp::verify() {
4194 if (getDataClauseOperands().empty())
4195 return emitError(
"at least one operand must be present in dataOperands on "
4196 "the exit data operation");
4200 if (getAsyncOperand() && getAsync())
4201 return emitError(
"async attribute cannot appear with asyncOperand");
4205 if (!getWaitOperands().empty() && getWait())
4206 return emitError(
"wait attribute cannot appear with waitOperands");
4208 if (getWaitDevnum() && getWaitOperands().empty())
4209 return emitError(
"wait_devnum cannot appear without waitOperands");
4214unsigned ExitDataOp::getNumDataOperands() {
4215 return getDataClauseOperands().size();
4218Value ExitDataOp::getDataOperand(
unsigned i) {
4219 unsigned numOptional = getIfCond() ? 1 : 0;
4220 numOptional += getAsyncOperand() ? 1 : 0;
4221 numOptional += getWaitDevnum() ? 1 : 0;
4222 return getOperand(getWaitOperands().size() + numOptional + i);
4227 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4230void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4232 assert(effectiveDeviceTypes.empty());
4233 assert(!getAsyncAttr());
4234 assert(!getAsyncOperand());
4236 setAsyncAttr(mlir::UnitAttr::get(context));
4239void ExitDataOp::addAsyncOperand(
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getAsyncAttr());
4244 assert(!getAsyncOperand());
4246 getAsyncOperandMutable().append(newValue);
4251 assert(effectiveDeviceTypes.empty());
4252 assert(!getWaitAttr());
4253 assert(getWaitOperands().empty());
4254 assert(!getWaitDevnum());
4256 setWaitAttr(mlir::UnitAttr::get(context));
4259void ExitDataOp::addWaitOperands(
4262 assert(effectiveDeviceTypes.empty());
4263 assert(!getWaitAttr());
4264 assert(getWaitOperands().empty());
4265 assert(!getWaitDevnum());
4270 getWaitDevnumMutable().append(newValues.front());
4271 newValues = newValues.drop_front();
4274 getWaitOperandsMutable().append(newValues);
4281LogicalResult acc::EnterDataOp::verify() {
4285 if (getDataClauseOperands().empty())
4286 return emitError(
"at least one operand must be present in dataOperands on "
4287 "the enter data operation");
4291 if (getAsyncOperand() && getAsync())
4292 return emitError(
"async attribute cannot appear with asyncOperand");
4296 if (!getWaitOperands().empty() && getWait())
4297 return emitError(
"wait attribute cannot appear with waitOperands");
4299 if (getWaitDevnum() && getWaitOperands().empty())
4300 return emitError(
"wait_devnum cannot appear without waitOperands");
4302 for (
mlir::Value operand : getDataClauseOperands())
4303 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4304 operand.getDefiningOp()))
4305 return emitError(
"expect data entry operation as defining op");
4310unsigned EnterDataOp::getNumDataOperands() {
4311 return getDataClauseOperands().size();
4314Value EnterDataOp::getDataOperand(
unsigned i) {
4315 unsigned numOptional = getIfCond() ? 1 : 0;
4316 numOptional += getAsyncOperand() ? 1 : 0;
4317 numOptional += getWaitDevnum() ? 1 : 0;
4318 return getOperand(getWaitOperands().size() + numOptional + i);
4323 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4326void EnterDataOp::addAsyncOnly(
4328 assert(effectiveDeviceTypes.empty());
4329 assert(!getAsyncAttr());
4330 assert(!getAsyncOperand());
4332 setAsyncAttr(mlir::UnitAttr::get(context));
4335void EnterDataOp::addAsyncOperand(
4338 assert(effectiveDeviceTypes.empty());
4339 assert(!getAsyncAttr());
4340 assert(!getAsyncOperand());
4342 getAsyncOperandMutable().append(newValue);
4345void EnterDataOp::addWaitOnly(
MLIRContext *context,
4347 assert(effectiveDeviceTypes.empty());
4348 assert(!getWaitAttr());
4349 assert(getWaitOperands().empty());
4350 assert(!getWaitDevnum());
4352 setWaitAttr(mlir::UnitAttr::get(context));
4355void EnterDataOp::addWaitOperands(
4358 assert(effectiveDeviceTypes.empty());
4359 assert(!getWaitAttr());
4360 assert(getWaitOperands().empty());
4361 assert(!getWaitDevnum());
4366 getWaitDevnumMutable().append(newValues.front());
4367 newValues = newValues.drop_front();
4370 getWaitOperandsMutable().append(newValues);
// AtomicReadOp verifier: all checks are delegated to verifyCommon().
4377LogicalResult AtomicReadOp::verify() {
 return verifyCommon(); }
// AtomicWriteOp verifier: all checks are delegated to verifyCommon().
4383LogicalResult AtomicWriteOp::verify() {
 return verifyCommon(); }
4389LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4396 if (
Value writeVal = op.getWriteOpVal()) {
// AtomicUpdateOp operation verifier: all checks are delegated to
// verifyCommon(); region checks live in verifyRegions().
4405LogicalResult AtomicUpdateOp::verify() {
 return verifyCommon(); }
// AtomicUpdateOp region verifier: delegates to verifyRegionsCommon().
4407LogicalResult AtomicUpdateOp::verifyRegions() {
 return verifyRegionsCommon(); }
4413AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4414 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4416 return dyn_cast<AtomicReadOp>(getSecondOp());
4419AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4420 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4422 return dyn_cast<AtomicWriteOp>(getSecondOp());
4425AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4426 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4428 return dyn_cast<AtomicUpdateOp>(getSecondOp());
// AtomicCaptureOp region verifier: delegates to verifyRegionsCommon().
4431LogicalResult AtomicCaptureOp::verifyRegions() {
 return verifyRegionsCommon(); }
4437template <
typename Op>
4440 bool requireAtLeastOneOperand =
true) {
4441 if (operands.empty() && requireAtLeastOneOperand)
4444 "at least one operand must appear on the declare operation");
4447 if (isa<BlockArgument>(operand) ||
4448 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4449 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4450 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4451 operand.getDefiningOp()))
4453 "expect valid declare data entry operation or acc.getdeviceptr "
4457 assert(var &&
"declare operands can only be data entry operations which "
4460 std::optional<mlir::acc::DataClause> dataClauseOptional{
4462 assert(dataClauseOptional.has_value() &&
4463 "declare operands can only be data entry operations which must have "
4465 (
void)dataClauseOptional;
4471LogicalResult acc::DeclareEnterOp::verify() {
4479LogicalResult acc::DeclareExitOp::verify() {
4490LogicalResult acc::DeclareOp::verify() {
4499 acc::DeviceType dtype) {
4500 unsigned parallelism = 0;
4501 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4502 parallelism += op.hasWorker(dtype) ? 1 : 0;
4503 parallelism += op.hasVector(dtype) ? 1 : 0;
4504 parallelism += op.hasSeq(dtype) ? 1 : 0;
4508LogicalResult acc::RoutineOp::verify() {
4509 unsigned baseParallelism =
4512 if (baseParallelism > 1)
4513 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4514 "be present at the same time";
4516 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4518 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4519 if (dtype == acc::DeviceType::None)
4523 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4524 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4525 "be present at the same time for device_type `"
4526 << acc::stringifyDeviceType(dtype) <<
"`";
4533 mlir::ArrayAttr &bindIdName,
4534 mlir::ArrayAttr &bindStrName,
4535 mlir::ArrayAttr &deviceIdTypes,
4536 mlir::ArrayAttr &deviceStrTypes) {
4543 mlir::Attribute newAttr;
4544 bool isSymbolRefAttr;
4545 auto parseResult = parser.parseAttribute(newAttr);
4546 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4547 bindIdNameAttrs.push_back(symbolRefAttr);
4548 isSymbolRefAttr = true;
4549 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4550 bindStrNameAttrs.push_back(stringAttr);
4551 isSymbolRefAttr =
false;
4556 if (isSymbolRefAttr) {
4557 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4558 parser.getContext(), mlir::acc::DeviceType::None));
4560 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4561 parser.getContext(), mlir::acc::DeviceType::None));
4564 if (isSymbolRefAttr) {
4565 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4566 parser.parseRSquare())
4569 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4570 parser.parseRSquare())
4578 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4579 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4580 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4581 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4587 std::optional<mlir::ArrayAttr> bindIdName,
4588 std::optional<mlir::ArrayAttr> bindStrName,
4589 std::optional<mlir::ArrayAttr> deviceIdTypes,
4590 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4597 allBindNames.append(bindIdName->begin(), bindIdName->end());
4598 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4603 allBindNames.append(bindStrName->begin(), bindStrName->end());
4604 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4608 if (!allBindNames.empty())
4609 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4610 [&](
const auto &pair) {
4611 p << std::get<0>(pair);
4617 mlir::ArrayAttr &gang,
4618 mlir::ArrayAttr &gangDim,
4619 mlir::ArrayAttr &gangDimDeviceTypes) {
4622 gangDimDeviceTypeAttrs;
4623 bool needCommaBeforeOperands =
false;
4627 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4628 parser.
getContext(), mlir::acc::DeviceType::None));
4629 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4636 if (parser.parseAttribute(gangAttrs.emplace_back()))
4643 needCommaBeforeOperands =
true;
4646 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4650 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4651 parser.parseColon() ||
4652 parser.parseAttribute(gangDimAttrs.emplace_back()))
4654 if (succeeded(parser.parseOptionalLSquare())) {
4655 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4656 parser.parseRSquare())
4659 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4660 parser.getContext(), mlir::acc::DeviceType::None));
4666 if (
failed(parser.parseRParen()))
4669 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4670 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4671 gangDimDeviceTypes =
4672 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4678 std::optional<mlir::ArrayAttr> gang,
4679 std::optional<mlir::ArrayAttr> gangDim,
4680 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4683 gang->size() == 1) {
4684 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4685 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4697 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4698 [&](
const auto &pair) {
4699 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4700 p << std::get<0>(pair);
4708 mlir::ArrayAttr &deviceTypes) {
4712 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4713 parser.
getContext(), mlir::acc::DeviceType::None));
4714 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4721 if (parser.parseAttribute(attributes.emplace_back()))
4729 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4735 std::optional<mlir::ArrayAttr> deviceTypes) {
4738 auto deviceTypeAttr =
4739 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4740 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4749 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
// Overload without an explicit device type; equivalent to
// hasWorker(mlir::acc::DeviceType::None).
4755bool RoutineOp::hasWorker() {
 return hasWorker(mlir::acc::DeviceType::None); }
4757bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
// Overload without an explicit device type; equivalent to
// hasVector(mlir::acc::DeviceType::None).
4761bool RoutineOp::hasVector() {
 return hasVector(mlir::acc::DeviceType::None); }
4763bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
// Overload without an explicit device type; equivalent to
// hasSeq(mlir::acc::DeviceType::None).
4767bool RoutineOp::hasSeq() {
 return hasSeq(mlir::acc::DeviceType::None); }
4769bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4773std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4774RoutineOp::getBindNameValue() {
4775 return getBindNameValue(mlir::acc::DeviceType::None);
4778std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4779RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4782 return std::nullopt;
4785 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4786 auto attr = (*getBindIdName())[*pos];
4787 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4788 assert(symbolRefAttr &&
"expected SymbolRef");
4789 return symbolRefAttr;
4792 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4793 auto attr = (*getBindStrName())[*pos];
4794 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4795 assert(stringAttr &&
"expected String");
4799 return std::nullopt;
// Overload without an explicit device type; equivalent to
// hasGang(mlir::acc::DeviceType::None).
4802bool RoutineOp::hasGang() {
 return hasGang(mlir::acc::DeviceType::None); }
4804bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4808std::optional<int64_t> RoutineOp::getGangDimValue() {
4809 return getGangDimValue(mlir::acc::DeviceType::None);
4812std::optional<int64_t>
4813RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4815 return std::nullopt;
4816 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4817 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4818 return intAttr.getInt();
4820 return std::nullopt;
4825 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4826 effectiveDeviceTypes));
4831 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4832 effectiveDeviceTypes));
4837 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4838 effectiveDeviceTypes));
4843 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4844 effectiveDeviceTypes));
4853 if (getGangDimAttr())
4854 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4855 if (getGangDimDeviceTypeAttr())
4856 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4858 assert(dimValues.size() == deviceTypes.size());
4860 if (effectiveDeviceTypes.empty()) {
4861 dimValues.push_back(
4862 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4863 deviceTypes.push_back(
4864 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4866 for (DeviceType dt : effectiveDeviceTypes) {
4867 dimValues.push_back(
4868 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4869 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4872 assert(dimValues.size() == deviceTypes.size());
4874 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4875 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4878void RoutineOp::addBindStrName(
MLIRContext *context,
4880 mlir::StringAttr val) {
4881 unsigned before = getBindStrNameDeviceTypeAttr()
4882 ? getBindStrNameDeviceTypeAttr().size()
4885 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4886 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4887 unsigned after = getBindStrNameDeviceTypeAttr().size();
4890 if (getBindStrNameAttr())
4891 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4892 for (
unsigned i = 0; i < after - before; ++i)
4893 vals.push_back(val);
4895 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4898void RoutineOp::addBindIDName(
MLIRContext *context,
4900 mlir::SymbolRefAttr val) {
4902 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4904 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4905 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4906 unsigned after = getBindIdNameDeviceTypeAttr().size();
4909 if (getBindIdNameAttr())
4910 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4911 for (
unsigned i = 0; i < after - before; ++i)
4912 vals.push_back(val);
4914 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4921LogicalResult acc::InitOp::verify() {
4925 return emitOpError(
"cannot be nested in a compute operation");
4929void acc::InitOp::addDeviceType(
MLIRContext *context,
4930 mlir::acc::DeviceType deviceType) {
4932 if (getDeviceTypesAttr())
4933 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4935 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4936 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4943LogicalResult acc::ShutdownOp::verify() {
4947 return emitOpError(
"cannot be nested in a compute operation");
4951void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4952 mlir::acc::DeviceType deviceType) {
4954 if (getDeviceTypesAttr())
4955 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4957 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4958 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4965LogicalResult acc::SetOp::verify() {
4969 return emitOpError(
"cannot be nested in a compute operation");
4970 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4971 return emitOpError(
"at least one default_async, device_num, or device_type "
4972 "operand must appear");
4980LogicalResult acc::UpdateOp::verify() {
4982 if (getDataClauseOperands().empty())
4983 return emitError(
"at least one value must be present in dataOperands");
4986 getAsyncOperandsDeviceTypeAttr(),
4991 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4992 getWaitOperandsDeviceTypeAttr(),
"wait")))
4998 for (
mlir::Value operand : getDataClauseOperands())
4999 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
5000 operand.getDefiningOp()))
5001 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
5007unsigned UpdateOp::getNumDataOperands() {
5008 return getDataClauseOperands().size();
5011Value UpdateOp::getDataOperand(
unsigned i) {
5013 numOptional += getIfCond() ? 1 : 0;
5014 return getOperand(getWaitOperands().size() + numOptional + i);
5019 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
5022bool UpdateOp::hasAsyncOnly() {
5023 return hasAsyncOnly(mlir::acc::DeviceType::None);
5026bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5031 return getAsyncValue(mlir::acc::DeviceType::None);
5034mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
5044bool UpdateOp::hasWaitOnly() {
5045 return hasWaitOnly(mlir::acc::DeviceType::None);
5048bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5053 return getWaitValues(mlir::acc::DeviceType::None);
5057UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
5059 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
5060 getHasWaitDevnum(), deviceType);
5064 return getWaitDevnum(mlir::acc::DeviceType::None);
5067mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5069 getWaitOperandsSegments(), getHasWaitDevnum(),
5075 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5076 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5079void UpdateOp::addAsyncOperand(
5082 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5083 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5084 getAsyncOperandsMutable()));
5089 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5090 effectiveDeviceTypes));
5093void UpdateOp::addWaitOperands(
5098 if (getWaitOperandsSegments())
5099 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5101 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5102 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5103 getWaitOperandsMutable(), segments));
5104 setWaitOperandsSegments(segments);
5107 if (getHasWaitDevnumAttr())
5108 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5111 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5113 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5120LogicalResult acc::WaitOp::verify() {
5123 if (getAsyncOperand() && getAsync())
5124 return emitError(
"async attribute cannot appear with asyncOperand");
5126 if (getWaitDevnum() && getWaitOperands().empty())
5127 return emitError(
"wait_devnum cannot appear without waitOperands");
5132#define GET_OP_CLASSES
5133#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5135#define GET_ATTRDEF_CLASSES
5136#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5138#define GET_TYPEDEF_CLASSES
5139#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5150 .Case<ACC_DATA_ENTRY_OPS>(
5151 [&](
auto entry) {
return entry.getVarPtr(); })
5152 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5153 [&](
auto exit) {
return exit.getVarPtr(); })
5171 [&](
auto entry) {
return entry.getVarType(); })
5172 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5173 [&](
auto exit) {
return exit.getVarType(); })
5183 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5184 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5194 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5203 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5213 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5215 dataClause.getBounds().begin(), dataClause.getBounds().end());
5227 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5229 dataClause.getAsyncOperands().begin(),
5230 dataClause.getAsyncOperands().end());
5241 return dataClause.getAsyncOperandsDeviceTypeAttr();
5249 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5256 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5263std::optional<mlir::acc::DataClause>
5268 .Case<ACC_DATA_ENTRY_OPS>(
5269 [&](
auto entry) {
return entry.getDataClause(); })
5277 [&](
auto entry) {
return entry.getImplicit(); })
5286 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5288 return dataOperands;
5296 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5298 return dataOperands;
5305 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.