25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249 bool isDeviceData(Type pointer, Value var)
const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
256struct LLVMPointerPointerLikeModel
257 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
261 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType)
const {
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
278struct MemrefAddressOfGlobalModel
279 :
public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op)
const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
287struct MemrefGlobalVariableModel
288 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
290 bool isConstant(Operation *op)
const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
295 Region *getInitRegion(Operation *op)
const {
300 bool isDeviceData(Operation *op)
const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
307struct GPULaunchOffloadRegionModel
308 :
public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
311 return cast<gpu::LaunchOp>(op).getBody();
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
333 return mlir::ArrayAttr::get(context, deviceTypes);
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
351 if (newDeviceTypes.empty()) {
352 argCollection.
append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.
append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
364 return mlir::ArrayAttr::get(context, deviceTypes);
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
383void OpenACCDialect::initialize() {
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
448void ParallelOp::getSuccessorRegions(
478void HostDataOp::getSuccessorRegions(
493 if (getUnstructured()) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
530 mlir::acc::DeviceType deviceType) {
534 for (
auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
544 std::optional<mlir::ArrayAttr> deviceTypes) {
549 llvm::interleaveComma(*deviceTypes, p,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (
auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
570 mlir::acc::DeviceType deviceType) {
572 return range.take_front(0);
573 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (
unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
579 return range.take_front(0);
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
590 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
608 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1);
618template <
typename Op>
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
622 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
627 op.hasAsyncOnly(dtype))
629 "asyncOnly attribute cannot appear with asyncOperand");
634 op.hasWaitOnly(dtype))
635 return op.
emitError(
"wait attribute cannot appear with waitOperands");
640template <
typename Op>
643 return op.
emitError(
"must have var operand");
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.
emitError(
"var must be mappable or pointer-like");
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.
emitError(
"varType must capture the element type of var");
658template <
typename Op>
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.
emitError(
"input and output types must match");
666template <
typename Op>
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.
emitError(
"no data clause modifiers are allowed");
673template <
typename Op>
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
684template <
typename OpT,
typename RecipeOpT>
685static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
690 !std::is_same_v<OpT, acc::ReductionOp>)
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
695 return op->emitOpError() <<
"recipe expected for " << operandName;
700 return op->emitOpError()
701 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
702 << operandName <<
" declaration";
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
744 if (failed(parser.
parseType(accVarType)))
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.
parseType(varPtrType)))
778 varTypeAttr = mlir::TypeAttr::get(varType);
783 if (
auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
784 Type elementType = ptrTy.getElementType();
787 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
789 varTypeAttr = mlir::TypeAttr::get(varPtrType);
797 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
805 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
806 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
810 if (!typeToCheckAgainst)
811 typeToCheckAgainst = varPtrType;
812 if (typeToCheckAgainst != varType) {
820 mlir::SymbolRefAttr &recipeAttr) {
827 mlir::SymbolRefAttr recipeAttr) {
834LogicalResult acc::DataBoundsOp::verify() {
835 auto extent = getExtent();
836 auto upperbound = getUpperbound();
837 if (!extent && !upperbound)
838 return emitError(
"expected extent or upperbound.");
845LogicalResult acc::PrivateOp::verify() {
848 "data clause associated with private operation must match its intent");
862LogicalResult acc::FirstprivateOp::verify() {
864 return emitError(
"data clause associated with firstprivate operation must "
871 *
this,
"firstprivate")))
879LogicalResult acc::ReductionOp::verify() {
881 return emitError(
"data clause associated with reduction operation must "
888 *
this,
"reduction")))
896LogicalResult acc::DevicePtrOp::verify() {
898 return emitError(
"data clause associated with deviceptr operation must "
912LogicalResult acc::PresentOp::verify() {
915 "data clause associated with present operation must match its intent");
928LogicalResult acc::CopyinOp::verify() {
930 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
935 "data clause associated with copyin operation must match its intent"
936 " or specify original clause this operation was decomposed from");
942 acc::DataClauseModifier::always |
943 acc::DataClauseModifier::capture)))
948bool acc::CopyinOp::isCopyinReadonly() {
949 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
950 acc::bitEnumContainsAny(getModifiers(),
951 acc::DataClauseModifier::readonly);
957LogicalResult acc::CreateOp::verify() {
964 "data clause associated with create operation must match its intent"
965 " or specify original clause this operation was decomposed from");
973 acc::DataClauseModifier::always |
974 acc::DataClauseModifier::capture)))
979bool acc::CreateOp::isCreateZero() {
981 return getDataClause() == acc::DataClause::acc_create_zero ||
983 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
989LogicalResult acc::NoCreateOp::verify() {
991 return emitError(
"data clause associated with no_create operation must "
1005LogicalResult acc::AttachOp::verify() {
1008 "data clause associated with attach operation must match its intent");
1022LogicalResult acc::DeclareDeviceResidentOp::verify() {
1023 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1024 return emitError(
"data clause associated with device_resident operation "
1025 "must match its intent");
1039LogicalResult acc::DeclareLinkOp::verify() {
1042 "data clause associated with link operation must match its intent");
1055LogicalResult acc::CopyoutOp::verify() {
1062 "data clause associated with copyout operation must match its intent"
1063 " or specify original clause this operation was decomposed from");
1065 return emitError(
"must have both host and device pointers");
1071 acc::DataClauseModifier::always |
1072 acc::DataClauseModifier::capture)))
1077bool acc::CopyoutOp::isCopyoutZero() {
1078 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1079 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1085LogicalResult acc::DeleteOp::verify() {
1094 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1097 "data clause associated with delete operation must match its intent"
1098 " or specify original clause this operation was decomposed from");
1100 return emitError(
"must have device pointer");
1104 acc::DataClauseModifier::readonly |
1105 acc::DataClauseModifier::always |
1106 acc::DataClauseModifier::capture)))
1114LogicalResult acc::DetachOp::verify() {
1119 "data clause associated with detach operation must match its intent"
1120 " or specify original clause this operation was decomposed from");
1122 return emitError(
"must have device pointer");
1131LogicalResult acc::UpdateHostOp::verify() {
1136 "data clause associated with host operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1139 return emitError(
"must have both host and device pointers");
1152LogicalResult acc::UpdateDeviceOp::verify() {
1156 "data clause associated with device operation must match its intent"
1157 " or specify original clause this operation was decomposed from");
1170LogicalResult acc::UseDeviceOp::verify() {
1174 "data clause associated with use_device operation must match its intent"
1175 " or specify original clause this operation was decomposed from");
1188LogicalResult acc::CacheOp::verify() {
1193 "data clause associated with cache operation must match its intent"
1194 " or specify original clause this operation was decomposed from");
1204bool acc::CacheOp::isCacheReadonly() {
1205 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1206 acc::bitEnumContainsAny(getModifiers(),
1207 acc::DataClauseModifier::readonly);
1223template <
typename EffectTy>
1228 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1229 effects.emplace_back(EffectTy::get(), &operand[i]);
1233template <
typename EffectTy>
1238 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1242void acc::PrivateOp::getEffects(
1256void acc::FirstprivateOp::getEffects(
1270void acc::ReductionOp::getEffects(
1284void acc::DevicePtrOp::getEffects(
1293void acc::PresentOp::getEffects(
1304void acc::CopyinOp::getEffects(
1317void acc::CreateOp::getEffects(
1330void acc::NoCreateOp::getEffects(
1341void acc::AttachOp::getEffects(
1354void acc::GetDevicePtrOp::getEffects(
1363void acc::UpdateDeviceOp::getEffects(
1373void acc::UseDeviceOp::getEffects(
1382void acc::DeclareDeviceResidentOp::getEffects(
1393void acc::DeclareLinkOp::getEffects(
1404void acc::CacheOp::getEffects(
1409void acc::CopyoutOp::getEffects(
1422void acc::DeleteOp::getEffects(
1434void acc::DetachOp::getEffects(
1446void acc::UpdateHostOp::getEffects(
1458template <
typename StructureOp>
1460 unsigned nRegions = 1) {
1463 for (
unsigned i = 0; i < nRegions; ++i)
1466 for (
Region *region : regions)
1477template <
typename OpTy>
1479 using OpRewritePattern<OpTy>::OpRewritePattern;
1481 LogicalResult matchAndRewrite(OpTy op,
1482 PatternRewriter &rewriter)
const override {
1484 Value ifCond = op.getIfCond();
1488 IntegerAttr constAttr;
1491 if (constAttr.getInt())
1492 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1504 assert(region.
hasOneBlock() &&
"expected single-block region");
1516template <
typename OpTy>
1517struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1518 using OpRewritePattern<OpTy>::OpRewritePattern;
1520 LogicalResult matchAndRewrite(OpTy op,
1521 PatternRewriter &rewriter)
const override {
1523 Value ifCond = op.getIfCond();
1527 IntegerAttr constAttr;
1530 if (constAttr.getInt())
1531 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1553 for (
Value bound : bounds) {
1554 argTypes.push_back(bound.getType());
1555 argLocs.push_back(loc);
1562 Value privatizedValue;
1568 if (isa<MappableType>(varType)) {
1569 auto mappableTy = cast<MappableType>(varType);
1570 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1571 privatizedValue = mappableTy.generatePrivateInit(
1572 builder, loc, typedVar, varName, bounds, {}, needsFree);
1573 if (!privatizedValue)
1576 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1577 auto pointerLikeTy = cast<PointerLikeType>(varType);
1579 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1580 blockArgVar, needsFree);
1581 if (!privatizedValue)
1586 acc::YieldOp::create(builder, loc, privatizedValue);
1601 for (
Value bound : bounds) {
1602 copyArgTypes.push_back(bound.getType());
1603 copyArgLocs.push_back(loc);
1610 bool isMappable = isa<MappableType>(varType);
1611 bool isPointerLike = isa<PointerLikeType>(varType);
1614 if (isMappable && !isPointerLike)
1618 if (isPointerLike) {
1619 auto pointerLikeTy = cast<PointerLikeType>(varType);
1624 if (!pointerLikeTy.genCopy(
1631 acc::TerminatorOp::create(builder, loc);
1645 for (
Value bound : bounds) {
1646 destroyArgTypes.push_back(bound.getType());
1647 destroyArgLocs.push_back(loc);
1651 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1655 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1656 if (isa<MappableType>(varType)) {
1657 auto mappableTy = cast<MappableType>(varType);
1658 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1661 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1662 auto pointerLikeTy = cast<PointerLikeType>(varType);
1663 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1667 acc::TerminatorOp::create(builder, loc);
1678 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1680 if (optional && region.
empty())
1684 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1688 return op->
emitOpError() <<
"expects " << regionName
1691 << regionType <<
" type";
1694 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1695 if (yieldOp.getOperands().size() != 1 ||
1696 yieldOp.getOperands().getTypes()[0] != type)
1697 return op->
emitOpError() <<
"expects " << regionName
1699 "yield a value of the "
1700 << regionType <<
" type";
1706LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1708 "privatization",
"init",
getType(),
1712 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1718std::optional<PrivateRecipeOp>
1720 StringRef recipeName,
Type varType,
1723 bool isMappable = isa<MappableType>(varType);
1724 bool isPointerLike = isa<PointerLikeType>(varType);
1727 if (!isMappable && !isPointerLike)
1728 return std::nullopt;
1733 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1736 bool needsFree =
false;
1737 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1738 varName, bounds, needsFree))) {
1740 return std::nullopt;
1747 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1748 Value allocRes = yieldOp.getOperand(0);
1750 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1751 varType, allocRes, bounds))) {
1753 return std::nullopt;
1760std::optional<PrivateRecipeOp>
1762 StringRef recipeName,
1763 FirstprivateRecipeOp firstprivRecipe) {
1766 auto varType = firstprivRecipe.getType();
1767 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1771 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1774 if (!firstprivRecipe.getDestroyRegion().empty()) {
1776 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1786LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1788 "privatization",
"init",
getType(),
1792 if (getCopyRegion().empty())
1793 return emitOpError() <<
"expects non-empty copy region";
1798 return emitOpError() <<
"expects copy region with two arguments of the "
1799 "privatization type";
1801 if (getDestroyRegion().empty())
1805 "privatization",
"destroy",
1812std::optional<FirstprivateRecipeOp>
1814 StringRef recipeName,
Type varType,
1817 bool isMappable = isa<MappableType>(varType);
1818 bool isPointerLike = isa<PointerLikeType>(varType);
1821 if (!isMappable && !isPointerLike)
1822 return std::nullopt;
1827 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1830 bool needsFree =
false;
1831 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1832 varName, bounds, needsFree))) {
1834 return std::nullopt;
1838 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1841 return std::nullopt;
1848 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1849 Value allocRes = yieldOp.getOperand(0);
1851 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1852 varType, allocRes, bounds))) {
1854 return std::nullopt;
1865LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1871 if (getCombinerRegion().empty())
1872 return emitOpError() <<
"expects non-empty combiner region";
1874 Block &reductionBlock = getCombinerRegion().
front();
1878 return emitOpError() <<
"expects combiner region with the first two "
1879 <<
"arguments of the reduction type";
1881 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1882 if (yieldOp.getOperands().size() != 1 ||
1883 yieldOp.getOperands().getTypes()[0] !=
getType())
1884 return emitOpError() <<
"expects combiner region to yield a value "
1885 "of the reduction type";
1896template <
typename Op>
1900 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1901 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1902 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1903 operand.getDefiningOp()))
1905 "expect data entry/exit operation or acc.getdeviceptr "
1910template <
typename OpT,
typename RecipeOpT>
1913 llvm::StringRef operandName) {
1916 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1918 <<
"expected " << operandName <<
" as defining op";
1919 if (!set.insert(operand).second)
1921 << operandName <<
" operand appears more than once";
1926unsigned ParallelOp::getNumDataOperands() {
1927 return getReductionOperands().size() + getPrivateOperands().size() +
1928 getFirstprivateOperands().size() + getDataClauseOperands().size();
1931Value ParallelOp::getDataOperand(
unsigned i) {
1933 numOptional += getNumGangs().size();
1934 numOptional += getNumWorkers().size();
1935 numOptional += getVectorLength().size();
1936 numOptional += getIfCond() ? 1 : 0;
1937 numOptional += getSelfCond() ? 1 : 0;
1938 return getOperand(getWaitOperands().size() + numOptional + i);
1941template <
typename Op>
1944 llvm::StringRef keyword) {
1945 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1946 return op.
emitOpError() << keyword <<
" operands count must match "
1947 << keyword <<
" device_type count";
1951template <
typename Op>
1954 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1955 std::size_t numOperandsInSegments = 0;
1956 std::size_t nbOfSegments = 0;
1959 for (
auto segCount : segments.
asArrayRef()) {
1960 if (maxInSegment != 0 && segCount > maxInSegment)
1961 return op.
emitOpError() << keyword <<
" expects a maximum of "
1962 << maxInSegment <<
" values per segment";
1963 numOperandsInSegments += segCount;
1968 if ((numOperandsInSegments != operands.size()) ||
1969 (!deviceTypes && !operands.empty()))
1971 << keyword <<
" operand count does not match count in segments";
1972 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1974 << keyword <<
" segment count does not match device_type count";
1978LogicalResult acc::ParallelOp::verify() {
1980 mlir::acc::PrivateRecipeOp>(
1981 *
this, getPrivateOperands(),
"private")))
1984 mlir::acc::FirstprivateRecipeOp>(
1985 *
this, getFirstprivateOperands(),
"firstprivate")))
1988 mlir::acc::ReductionRecipeOp>(
1989 *
this, getReductionOperands(),
"reduction")))
1993 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1994 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1998 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1999 getWaitOperandsDeviceTypeAttr(),
"wait")))
2003 getNumWorkersDeviceTypeAttr(),
2008 getVectorLengthDeviceTypeAttr(),
2013 getAsyncOperandsDeviceTypeAttr(),
2026 mlir::acc::DeviceType deviceType) {
2029 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2034bool acc::ParallelOp::hasAsyncOnly() {
2035 return hasAsyncOnly(mlir::acc::DeviceType::None);
2038bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2043 return getAsyncValue(mlir::acc::DeviceType::None);
2046mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2051mlir::Value acc::ParallelOp::getNumWorkersValue() {
2052 return getNumWorkersValue(mlir::acc::DeviceType::None);
2056acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2061mlir::Value acc::ParallelOp::getVectorLengthValue() {
2062 return getVectorLengthValue(mlir::acc::DeviceType::None);
2066acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2068 getVectorLength(), deviceType);
2072 return getNumGangsValues(mlir::acc::DeviceType::None);
2076ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2078 getNumGangsSegments(), deviceType);
2081bool acc::ParallelOp::hasWaitOnly() {
2082 return hasWaitOnly(mlir::acc::DeviceType::None);
2085bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2090 return getWaitValues(mlir::acc::DeviceType::None);
2094ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2096 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2097 getHasWaitDevnum(), deviceType);
2101 return getWaitDevnum(mlir::acc::DeviceType::None);
2104mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2106 getWaitOperandsSegments(), getHasWaitDevnum(),
2121 odsBuilder, odsState, asyncOperands,
nullptr,
2122 nullptr, waitOperands,
nullptr,
2124 nullptr, numGangs,
nullptr,
2125 nullptr, numWorkers,
2126 nullptr, vectorLength,
2127 nullptr, ifCond, selfCond,
2128 nullptr, reductionOperands, gangPrivateOperands,
2129 gangFirstPrivateOperands, dataClauseOperands,
2133void acc::ParallelOp::addNumWorkersOperand(
2136 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2137 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2138 getNumWorkersMutable()));
2140void acc::ParallelOp::addVectorLengthOperand(
2143 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2144 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2145 getVectorLengthMutable()));
2148void acc::ParallelOp::addAsyncOnly(
2150 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2151 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2154void acc::ParallelOp::addAsyncOperand(
2157 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2158 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2159 getAsyncOperandsMutable()));
2162void acc::ParallelOp::addNumGangsOperands(
2166 if (getNumGangsSegments())
2167 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2169 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2170 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2171 getNumGangsMutable(), segments));
2173 setNumGangsSegments(segments);
2175void acc::ParallelOp::addWaitOnly(
2177 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2178 effectiveDeviceTypes));
2180void acc::ParallelOp::addWaitOperands(
2185 if (getWaitOperandsSegments())
2186 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2188 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2189 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2190 getWaitOperandsMutable(), segments));
2191 setWaitOperandsSegments(segments);
2194 if (getHasWaitDevnumAttr())
2195 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2198 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2200 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2203void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2204 mlir::acc::PrivateOp op,
2205 mlir::acc::PrivateRecipeOp recipe) {
2206 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2207 getPrivateOperandsMutable().append(op.getResult());
2210void acc::ParallelOp::addFirstPrivatization(
2211 MLIRContext *context, mlir::acc::FirstprivateOp op,
2212 mlir::acc::FirstprivateRecipeOp recipe) {
2213 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2214 getFirstprivateOperandsMutable().append(op.getResult());
2217void acc::ParallelOp::addReduction(
MLIRContext *context,
2218 mlir::acc::ReductionOp op,
2219 mlir::acc::ReductionRecipeOp recipe) {
2220 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2221 getReductionOperandsMutable().append(op.getResult());
2236 int32_t crtOperandsSize = operands.size();
2239 if (parser.parseOperand(operands.emplace_back()) ||
2240 parser.parseColonType(types.emplace_back()))
2245 seg.push_back(operands.size() - crtOperandsSize);
2255 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2256 parser.
getContext(), mlir::acc::DeviceType::None));
2262 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2269 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2270 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2271 p <<
" [" << attr <<
"]";
2276 std::optional<mlir::ArrayAttr> deviceTypes,
2277 std::optional<mlir::DenseI32ArrayAttr> segments) {
2279 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2281 llvm::interleaveComma(
2282 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2283 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2303 int32_t crtOperandsSize = operands.size();
2307 if (parser.parseOperand(operands.emplace_back()) ||
2308 parser.parseColonType(types.emplace_back()))
2314 seg.push_back(operands.size() - crtOperandsSize);
2324 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2325 parser.
getContext(), mlir::acc::DeviceType::None));
2331 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2340 std::optional<mlir::DenseI32ArrayAttr> segments) {
2342 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2344 llvm::interleaveComma(
2345 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2346 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2359 mlir::ArrayAttr &keywordOnly) {
2363 bool needCommaBeforeOperands =
false;
2367 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2368 parser.
getContext(), mlir::acc::DeviceType::None));
2369 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2376 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2383 needCommaBeforeOperands =
true;
2386 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2393 int32_t crtOperandsSize = operands.size();
2405 if (parser.parseOperand(operands.emplace_back()) ||
2406 parser.parseColonType(types.emplace_back()))
2412 seg.push_back(operands.size() - crtOperandsSize);
2422 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2423 parser.
getContext(), mlir::acc::DeviceType::None));
2430 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2431 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2433 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2441 if (attrs->size() != 1)
2443 if (
auto deviceTypeAttr =
2444 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2445 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2451 std::optional<mlir::ArrayAttr> deviceTypes,
2452 std::optional<mlir::DenseI32ArrayAttr> segments,
2453 std::optional<mlir::ArrayAttr> hasDevNum,
2454 std::optional<mlir::ArrayAttr> keywordOnly) {
2467 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2469 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2470 if (boolAttr && boolAttr.getValue())
2472 llvm::interleaveComma(
2473 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2474 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2491 if (parser.parseOperand(operands.emplace_back()) ||
2492 parser.parseColonType(types.emplace_back()))
2494 if (succeeded(parser.parseOptionalLSquare())) {
2495 if (parser.parseAttribute(attributes.emplace_back()) ||
2496 parser.parseRSquare())
2499 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2500 parser.getContext(), mlir::acc::DeviceType::None));
2507 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2514 std::optional<mlir::ArrayAttr> deviceTypes) {
2517 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2518 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2527 mlir::ArrayAttr &keywordOnlyDeviceType) {
2530 bool needCommaBeforeOperands =
false;
2534 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2535 parser.
getContext(), mlir::acc::DeviceType::None));
2536 keywordOnlyDeviceType =
2537 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2545 if (parser.parseAttribute(
2546 keywordOnlyDeviceTypeAttributes.emplace_back()))
2553 needCommaBeforeOperands =
true;
2556 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2561 if (parser.parseOperand(operands.emplace_back()) ||
2562 parser.parseColonType(types.emplace_back()))
2564 if (succeeded(parser.parseOptionalLSquare())) {
2565 if (parser.parseAttribute(attributes.emplace_back()) ||
2566 parser.parseRSquare())
2569 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2570 parser.getContext(), mlir::acc::DeviceType::None));
2576 if (
failed(parser.parseRParen()))
2581 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2588 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2590 if (operands.begin() == operands.end() &&
2606 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2607 mlir::Type &operandType, mlir::UnitAttr &attr) {
2610 attr = mlir::UnitAttr::get(parser.
getContext());
2620 if (failed(parser.
parseType(operandType)))
2630 std::optional<mlir::Value> operand,
2632 mlir::UnitAttr attr) {
2649 attr = mlir::UnitAttr::get(parser.
getContext());
2654 if (parser.parseOperand(operands.emplace_back()))
2662 if (parser.parseType(types.emplace_back()))
2677 mlir::UnitAttr attr) {
2682 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2684 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2690 mlir::acc::CombinedConstructsTypeAttr &attr) {
2692 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2693 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2695 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2696 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2698 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2699 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2702 "expected compute construct name");
2710 mlir::acc::CombinedConstructsTypeAttr attr) {
2712 switch (attr.getValue()) {
2713 case mlir::acc::CombinedConstructsType::KernelsLoop:
2716 case mlir::acc::CombinedConstructsType::ParallelLoop:
2719 case mlir::acc::CombinedConstructsType::SerialLoop:
2730unsigned SerialOp::getNumDataOperands() {
2731 return getReductionOperands().size() + getPrivateOperands().size() +
2732 getFirstprivateOperands().size() + getDataClauseOperands().size();
2735Value SerialOp::getDataOperand(
unsigned i) {
2737 numOptional += getIfCond() ? 1 : 0;
2738 numOptional += getSelfCond() ? 1 : 0;
2739 return getOperand(getWaitOperands().size() + numOptional + i);
2742bool acc::SerialOp::hasAsyncOnly() {
2743 return hasAsyncOnly(mlir::acc::DeviceType::None);
2746bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2751 return getAsyncValue(mlir::acc::DeviceType::None);
2754mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2759bool acc::SerialOp::hasWaitOnly() {
2760 return hasWaitOnly(mlir::acc::DeviceType::None);
2763bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2768 return getWaitValues(mlir::acc::DeviceType::None);
2772SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2774 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2775 getHasWaitDevnum(), deviceType);
2779 return getWaitDevnum(mlir::acc::DeviceType::None);
2782mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2784 getWaitOperandsSegments(), getHasWaitDevnum(),
2788LogicalResult acc::SerialOp::verify() {
2790 mlir::acc::PrivateRecipeOp>(
2791 *
this, getPrivateOperands(),
"private")))
2794 mlir::acc::FirstprivateRecipeOp>(
2795 *
this, getFirstprivateOperands(),
"firstprivate")))
2798 mlir::acc::ReductionRecipeOp>(
2799 *
this, getReductionOperands(),
"reduction")))
2803 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2804 getWaitOperandsDeviceTypeAttr(),
"wait")))
2808 getAsyncOperandsDeviceTypeAttr(),
2818void acc::SerialOp::addAsyncOnly(
2820 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2821 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2824void acc::SerialOp::addAsyncOperand(
2827 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2828 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2829 getAsyncOperandsMutable()));
2832void acc::SerialOp::addWaitOnly(
2834 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2835 effectiveDeviceTypes));
2837void acc::SerialOp::addWaitOperands(
2842 if (getWaitOperandsSegments())
2843 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2845 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2846 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2847 getWaitOperandsMutable(), segments));
2848 setWaitOperandsSegments(segments);
2851 if (getHasWaitDevnumAttr())
2852 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2855 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2857 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2860void acc::SerialOp::addPrivatization(
MLIRContext *context,
2861 mlir::acc::PrivateOp op,
2862 mlir::acc::PrivateRecipeOp recipe) {
2863 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2864 getPrivateOperandsMutable().append(op.getResult());
2867void acc::SerialOp::addFirstPrivatization(
2868 MLIRContext *context, mlir::acc::FirstprivateOp op,
2869 mlir::acc::FirstprivateRecipeOp recipe) {
2870 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2871 getFirstprivateOperandsMutable().append(op.getResult());
2874void acc::SerialOp::addReduction(
MLIRContext *context,
2875 mlir::acc::ReductionOp op,
2876 mlir::acc::ReductionRecipeOp recipe) {
2877 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2878 getReductionOperandsMutable().append(op.getResult());
2885unsigned KernelsOp::getNumDataOperands() {
2886 return getDataClauseOperands().size();
2889Value KernelsOp::getDataOperand(
unsigned i) {
2891 numOptional += getWaitOperands().size();
2892 numOptional += getNumGangs().size();
2893 numOptional += getNumWorkers().size();
2894 numOptional += getVectorLength().size();
2895 numOptional += getIfCond() ? 1 : 0;
2896 numOptional += getSelfCond() ? 1 : 0;
2897 return getOperand(numOptional + i);
2900bool acc::KernelsOp::hasAsyncOnly() {
2901 return hasAsyncOnly(mlir::acc::DeviceType::None);
2904bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2909 return getAsyncValue(mlir::acc::DeviceType::None);
2912mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2918 return getNumWorkersValue(mlir::acc::DeviceType::None);
2922acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2927mlir::Value acc::KernelsOp::getVectorLengthValue() {
2928 return getVectorLengthValue(mlir::acc::DeviceType::None);
2932acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2934 getVectorLength(), deviceType);
2938 return getNumGangsValues(mlir::acc::DeviceType::None);
2942KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2944 getNumGangsSegments(), deviceType);
2947bool acc::KernelsOp::hasWaitOnly() {
2948 return hasWaitOnly(mlir::acc::DeviceType::None);
2951bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2956 return getWaitValues(mlir::acc::DeviceType::None);
2960KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2962 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2963 getHasWaitDevnum(), deviceType);
2967 return getWaitDevnum(mlir::acc::DeviceType::None);
2970mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2972 getWaitOperandsSegments(), getHasWaitDevnum(),
2976LogicalResult acc::KernelsOp::verify() {
2978 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2979 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2983 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2984 getWaitOperandsDeviceTypeAttr(),
"wait")))
2988 getNumWorkersDeviceTypeAttr(),
2993 getVectorLengthDeviceTypeAttr(),
2998 getAsyncOperandsDeviceTypeAttr(),
3008void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3009 mlir::acc::PrivateOp op,
3010 mlir::acc::PrivateRecipeOp recipe) {
3011 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3012 getPrivateOperandsMutable().append(op.getResult());
3015void acc::KernelsOp::addFirstPrivatization(
3016 MLIRContext *context, mlir::acc::FirstprivateOp op,
3017 mlir::acc::FirstprivateRecipeOp recipe) {
3018 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3019 getFirstprivateOperandsMutable().append(op.getResult());
3022void acc::KernelsOp::addReduction(
MLIRContext *context,
3023 mlir::acc::ReductionOp op,
3024 mlir::acc::ReductionRecipeOp recipe) {
3025 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3026 getReductionOperandsMutable().append(op.getResult());
3029void acc::KernelsOp::addNumWorkersOperand(
3032 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3033 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3034 getNumWorkersMutable()));
3037void acc::KernelsOp::addVectorLengthOperand(
3040 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3041 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3042 getVectorLengthMutable()));
3044void acc::KernelsOp::addAsyncOnly(
3046 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3047 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3050void acc::KernelsOp::addAsyncOperand(
3053 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3054 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3055 getAsyncOperandsMutable()));
3058void acc::KernelsOp::addNumGangsOperands(
3062 if (getNumGangsSegmentsAttr())
3063 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3065 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3066 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3067 getNumGangsMutable(), segments));
3069 setNumGangsSegments(segments);
3072void acc::KernelsOp::addWaitOnly(
3074 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3075 effectiveDeviceTypes));
3077void acc::KernelsOp::addWaitOperands(
3082 if (getWaitOperandsSegments())
3083 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3085 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3086 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3087 getWaitOperandsMutable(), segments));
3088 setWaitOperandsSegments(segments);
3091 if (getHasWaitDevnumAttr())
3092 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3095 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3097 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3104LogicalResult acc::HostDataOp::verify() {
3105 if (getDataClauseOperands().empty())
3106 return emitError(
"at least one operand must appear on the host_data "
3110 for (
mlir::Value operand : getDataClauseOperands()) {
3112 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3114 return emitError(
"expect data entry operation as defining op");
3117 if (!seenVars.insert(useDeviceOp.getVar()).second)
3118 return emitError(
"duplicate use_device variable");
3125 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3137 bool &needCommaBetweenValues,
bool &newValue) {
3144 attributes.push_back(gangArgType);
3145 needCommaBetweenValues =
true;
3156 mlir::ArrayAttr &gangOnlyDeviceType) {
3161 bool needCommaBetweenValues =
false;
3162 bool needCommaBeforeOperands =
false;
3166 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3167 parser.
getContext(), mlir::acc::DeviceType::None));
3168 gangOnlyDeviceType =
3169 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3177 if (parser.parseAttribute(
3178 gangOnlyDeviceTypeAttributes.emplace_back()))
3185 needCommaBeforeOperands =
true;
3188 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3189 mlir::acc::GangArgType::Num);
3190 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3191 mlir::acc::GangArgType::Dim);
3192 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3193 parser.
getContext(), mlir::acc::GangArgType::Static);
3196 if (needCommaBeforeOperands) {
3197 needCommaBeforeOperands =
false;
3204 int32_t crtOperandsSize = gangOperands.size();
3206 bool newValue =
false;
3207 bool needValue =
false;
3208 if (needCommaBetweenValues) {
3216 gangOperands, gangOperandsType,
3217 gangArgTypeAttributes, argNum,
3218 needCommaBetweenValues, newValue)))
3221 gangOperands, gangOperandsType,
3222 gangArgTypeAttributes, argDim,
3223 needCommaBetweenValues, newValue)))
3225 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3226 gangOperands, gangOperandsType,
3227 gangArgTypeAttributes, argStatic,
3228 needCommaBetweenValues, newValue)))
3231 if (!newValue && needValue) {
3233 "new value expected after comma");
3241 if (gangOperands.empty())
3244 "expect at least one of num, dim or static values");
3250 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3254 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3255 parser.
getContext(), mlir::acc::DeviceType::None));
3258 seg.push_back(gangOperands.size() - crtOperandsSize);
3266 gangArgTypeAttributes.end());
3267 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3268 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3271 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3272 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3280 std::optional<mlir::ArrayAttr> gangArgTypes,
3281 std::optional<mlir::ArrayAttr> deviceTypes,
3282 std::optional<mlir::DenseI32ArrayAttr> segments,
3283 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3285 if (operands.begin() == operands.end() &&
3300 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3302 llvm::interleaveComma(
3303 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3304 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3305 (*gangArgTypes)[opIdx]);
3306 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3307 p << LoopOp::getGangNumKeyword();
3308 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3309 p << LoopOp::getGangDimKeyword();
3310 else if (gangArgTypeAttr.getValue() ==
3311 mlir::acc::GangArgType::Static)
3312 p << LoopOp::getGangStaticKeyword();
3313 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3324 std::optional<mlir::ArrayAttr> segments,
3325 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3328 for (
auto attr : *segments) {
3329 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3330 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3338static std::optional<mlir::acc::DeviceType>
3340 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3342 return std::nullopt;
3343 for (
auto attr : deviceTypes) {
3344 auto deviceTypeAttr =
3345 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3346 if (!deviceTypeAttr)
3347 return mlir::acc::DeviceType::None;
3348 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3349 return deviceTypeAttr.getValue();
3351 return std::nullopt;
3354LogicalResult acc::LoopOp::verify() {
3355 if (getUpperbound().size() != getStep().size())
3356 return emitError() <<
"number of upperbounds expected to be the same as "
3359 if (getUpperbound().size() != getLowerbound().size())
3360 return emitError() <<
"number of upperbounds expected to be the same as "
3361 "number of lowerbounds";
3363 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3364 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3365 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3366 <<
" as upperbound size";
3369 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3370 return emitOpError() <<
"collapse device_type attr must be define when"
3371 <<
" collapse attr is present";
3373 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3374 getCollapseAttr().getValue().size() !=
3375 getCollapseDeviceTypeAttr().getValue().size())
3376 return emitOpError() <<
"collapse attribute count must match collapse"
3377 <<
" device_type count";
3378 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3380 << acc::stringifyDeviceType(*duplicateDeviceType)
3381 <<
"` found in collapseDeviceType attribute";
3384 if (!getGangOperands().empty()) {
3385 if (!getGangOperandsArgType())
3386 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3387 <<
" when gang operands are present";
3389 if (getGangOperands().size() !=
3390 getGangOperandsArgTypeAttr().getValue().size())
3391 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3392 <<
" gangOperands count";
3394 if (getGangAttr()) {
3397 << acc::stringifyDeviceType(*duplicateDeviceType)
3398 <<
"` found in gang attribute";
3402 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3403 getGangOperandsDeviceTypeAttr(),
"gang")))
3409 << acc::stringifyDeviceType(*duplicateDeviceType)
3410 <<
"` found in worker attribute";
3411 if (
auto duplicateDeviceType =
3414 << acc::stringifyDeviceType(*duplicateDeviceType)
3415 <<
"` found in workerNumOperandsDeviceType attribute";
3417 getWorkerNumOperandsDeviceTypeAttr(),
3424 << acc::stringifyDeviceType(*duplicateDeviceType)
3425 <<
"` found in vector attribute";
3426 if (
auto duplicateDeviceType =
3429 << acc::stringifyDeviceType(*duplicateDeviceType)
3430 <<
"` found in vectorOperandsDeviceType attribute";
3432 getVectorOperandsDeviceTypeAttr(),
3437 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3438 getTileOperandsDeviceTypeAttr(),
"tile")))
3442 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3446 return emitError() <<
"only one of auto, independent, seq can be present "
3452 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3453 return attr.getValue() == mlir::acc::DeviceType::None;
3455 bool hasDefaultSeq =
3457 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3460 bool hasDefaultIndependent =
3461 getIndependentAttr()
3463 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3466 bool hasDefaultAuto =
3468 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3471 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3473 <<
"at least one of auto, independent, seq must be present";
3478 for (
auto attr : getSeqAttr()) {
3479 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3480 if (hasVector(deviceTypeAttr.getValue()) ||
3481 getVectorValue(deviceTypeAttr.getValue()) ||
3482 hasWorker(deviceTypeAttr.getValue()) ||
3483 getWorkerValue(deviceTypeAttr.getValue()) ||
3484 hasGang(deviceTypeAttr.getValue()) ||
3485 getGangValue(mlir::acc::GangArgType::Num,
3486 deviceTypeAttr.getValue()) ||
3487 getGangValue(mlir::acc::GangArgType::Dim,
3488 deviceTypeAttr.getValue()) ||
3489 getGangValue(mlir::acc::GangArgType::Static,
3490 deviceTypeAttr.getValue()))
3491 return emitError() <<
"gang, worker or vector cannot appear with seq";
3496 mlir::acc::PrivateRecipeOp>(
3497 *
this, getPrivateOperands(),
"private")))
3501 mlir::acc::FirstprivateRecipeOp>(
3502 *
this, getFirstprivateOperands(),
"firstprivate")))
3506 mlir::acc::ReductionRecipeOp>(
3507 *
this, getReductionOperands(),
"reduction")))
3510 if (getCombined().has_value() &&
3511 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3512 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3513 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3514 return emitError(
"unexpected combined constructs attribute");
3518 if (getRegion().empty())
3519 return emitError(
"expected non-empty body.");
3521 if (getUnstructured()) {
3522 if (!isContainerLike())
3524 "unstructured acc.loop must not have induction variables");
3525 }
else if (isContainerLike()) {
3529 uint64_t collapseCount = getCollapseValue().value_or(1);
3530 if (getCollapseAttr()) {
3531 for (
auto collapseEntry : getCollapseAttr()) {
3532 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3533 if (intAttr.getValue().getZExtValue() > collapseCount)
3534 collapseCount = intAttr.getValue().getZExtValue();
3542 bool foundSibling =
false;
3544 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3546 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3548 foundSibling =
true;
3553 expectedParent = op;
3556 if (collapseCount == 0)
3562 return emitError(
"found sibling loops inside container-like acc.loop");
3563 if (collapseCount != 0)
3564 return emitError(
"failed to find enough loop-like operations inside "
3565 "container-like acc.loop");
3571unsigned LoopOp::getNumDataOperands() {
3572 return getReductionOperands().size() + getPrivateOperands().size() +
3573 getFirstprivateOperands().size();
3576Value LoopOp::getDataOperand(
unsigned i) {
3577 unsigned numOptional =
3578 getLowerbound().size() + getUpperbound().size() + getStep().size();
3579 numOptional += getGangOperands().size();
3580 numOptional += getVectorOperands().size();
3581 numOptional += getWorkerNumOperands().size();
3582 numOptional += getTileOperands().size();
3583 numOptional += getCacheOperands().size();
3584 return getOperand(numOptional + i);
3587bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3589bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3593bool LoopOp::hasIndependent() {
3594 return hasIndependent(mlir::acc::DeviceType::None);
3597bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3601bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3603bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3608 return getVectorValue(mlir::acc::DeviceType::None);
3611mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3613 getVectorOperands(), deviceType);
3616bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3618bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3623 return getWorkerValue(mlir::acc::DeviceType::None);
3626mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3628 getWorkerNumOperands(), deviceType);
3631bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3633bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3638 return getTileValues(mlir::acc::DeviceType::None);
3642LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3644 getTileOperandsSegments(), deviceType);
3647std::optional<int64_t> LoopOp::getCollapseValue() {
3648 return getCollapseValue(mlir::acc::DeviceType::None);
3651std::optional<int64_t>
3652LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3653 if (!getCollapseAttr())
3654 return std::nullopt;
3655 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3657 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3658 return intAttr.getValue().getZExtValue();
3660 return std::nullopt;
3663mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3664 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3667mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3668 mlir::acc::DeviceType deviceType) {
3669 if (getGangOperands().empty())
3671 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3672 int32_t nbOperandsBefore = 0;
3673 for (
unsigned i = 0; i < *pos; ++i)
3674 nbOperandsBefore += (*getGangOperandsSegments())[i];
3677 .drop_front(nbOperandsBefore)
3678 .take_front((*getGangOperandsSegments())[*pos]);
3680 int32_t argTypeIdx = nbOperandsBefore;
3681 for (
auto value : values) {
3682 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3683 (*getGangOperandsArgType())[argTypeIdx]);
3684 if (gangArgTypeAttr.getValue() == gangArgType)
3692bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3694bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3699 return {&getRegion()};
3743 if (!regionArgs.empty()) {
3744 p << acc::LoopOp::getControlKeyword() <<
"(";
3745 llvm::interleaveComma(regionArgs, p,
3747 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3748 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3749 <<
" : " << stepType <<
") ";
3756 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3757 effectiveDeviceTypes));
3760void acc::LoopOp::addIndependent(
3762 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3763 context, getIndependentAttr(), effectiveDeviceTypes));
3768 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3769 effectiveDeviceTypes));
3772void acc::LoopOp::setCollapseForDeviceTypes(
3774 llvm::APInt value) {
3778 assert((getCollapseAttr() ==
nullptr) ==
3779 (getCollapseDeviceTypeAttr() ==
nullptr));
3780 assert(value.getBitWidth() == 64);
3782 if (getCollapseAttr()) {
3783 for (
const auto &existing :
3784 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3785 newValues.push_back(std::get<0>(existing));
3786 newDeviceTypes.push_back(std::get<1>(existing));
3790 if (effectiveDeviceTypes.empty()) {
3793 newValues.push_back(
3794 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3795 newDeviceTypes.push_back(
3796 acc::DeviceTypeAttr::get(context, DeviceType::None));
3798 for (DeviceType dt : effectiveDeviceTypes) {
3799 newValues.push_back(
3800 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3801 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3805 setCollapseAttr(ArrayAttr::get(context, newValues));
3806 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3809void acc::LoopOp::setTileForDeviceTypes(
3813 if (getTileOperandsSegments())
3814 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3816 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3817 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3818 getTileOperandsMutable(), segments));
3820 setTileOperandsSegments(segments);
3823void acc::LoopOp::addVectorOperand(
3826 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3827 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3828 newValue, getVectorOperandsMutable()));
3831void acc::LoopOp::addEmptyVector(
3833 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3834 effectiveDeviceTypes));
3837void acc::LoopOp::addWorkerNumOperand(
3840 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3841 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3842 newValue, getWorkerNumOperandsMutable()));
3845void acc::LoopOp::addEmptyWorker(
3847 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3848 effectiveDeviceTypes));
3851void acc::LoopOp::addEmptyGang(
3853 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3854 effectiveDeviceTypes));
3857bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3858 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3859 return attr.getValue() == dt;
3861 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3862 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3865 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3867 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3869 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3875bool acc::LoopOp::hasDefaultGangWorkerVector() {
3876 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3877 hasGang() || getGangValue(GangArgType::Num) ||
3878 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3882acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3883 if (hasSeq(deviceType))
3884 return LoopParMode::loop_seq;
3885 if (hasAuto(deviceType))
3886 return LoopParMode::loop_auto;
3887 if (hasIndependent(deviceType))
3888 return LoopParMode::loop_independent;
3890 return LoopParMode::loop_seq;
3892 return LoopParMode::loop_auto;
3893 assert(hasIndependent() &&
3894 "loop must have default auto, seq, or independent");
3895 return LoopParMode::loop_independent;
3898void acc::LoopOp::addGangOperands(
3903 getGangOperandsSegments())
3904 llvm::copy(*existingSegments, std::back_inserter(segments));
3906 unsigned beforeCount = segments.size();
3908 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3909 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3910 getGangOperandsMutable(), segments));
3912 setGangOperandsSegments(segments);
3919 unsigned numAdded = segments.size() - beforeCount;
3923 if (getGangOperandsArgTypeAttr())
3924 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3926 for (
auto i : llvm::index_range(0u, numAdded)) {
3927 llvm::transform(argTypes, std::back_inserter(gangTypes),
3928 [=](mlir::acc::GangArgType gangTy) {
3929 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3934 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3938void acc::LoopOp::addPrivatization(
MLIRContext *context,
3939 mlir::acc::PrivateOp op,
3940 mlir::acc::PrivateRecipeOp recipe) {
3941 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3942 getPrivateOperandsMutable().append(op.getResult());
3945void acc::LoopOp::addFirstPrivatization(
3946 MLIRContext *context, mlir::acc::FirstprivateOp op,
3947 mlir::acc::FirstprivateRecipeOp recipe) {
3948 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3949 getFirstprivateOperandsMutable().append(op.getResult());
3952void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3953 mlir::acc::ReductionRecipeOp recipe) {
3954 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3955 getReductionOperandsMutable().append(op.getResult());
3962LogicalResult acc::DataOp::verify() {
3967 return emitError(
"at least one operand or the default attribute "
3968 "must appear on the data operation");
3970 for (
mlir::Value operand : getDataClauseOperands())
3971 if (isa<BlockArgument>(operand) ||
3972 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3973 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3974 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3975 operand.getDefiningOp()))
3976 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3985unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3987Value DataOp::getDataOperand(
unsigned i) {
3988 unsigned numOptional = getIfCond() ? 1 : 0;
3990 numOptional += getWaitOperands().size();
3991 return getOperand(numOptional + i);
3994bool acc::DataOp::hasAsyncOnly() {
3995 return hasAsyncOnly(mlir::acc::DeviceType::None);
3998bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4003 return getAsyncValue(mlir::acc::DeviceType::None);
4006mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4011bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
4013bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4018 return getWaitValues(mlir::acc::DeviceType::None);
4022DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4024 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4025 getHasWaitDevnum(), deviceType);
4029 return getWaitDevnum(mlir::acc::DeviceType::None);
4032mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4034 getWaitOperandsSegments(), getHasWaitDevnum(),
4038void acc::DataOp::addAsyncOnly(
4040 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4041 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4044void acc::DataOp::addAsyncOperand(
4047 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4048 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4049 getAsyncOperandsMutable()));
4052void acc::DataOp::addWaitOnly(
MLIRContext *context,
4054 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4055 effectiveDeviceTypes));
4058void acc::DataOp::addWaitOperands(
4063 if (getWaitOperandsSegments())
4064 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4066 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4067 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4068 getWaitOperandsMutable(), segments));
4069 setWaitOperandsSegments(segments);
4072 if (getHasWaitDevnumAttr())
4073 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4076 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4078 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4085LogicalResult acc::ExitDataOp::verify() {
4089 if (getDataClauseOperands().empty())
4090 return emitError(
"at least one operand must be present in dataOperands on "
4091 "the exit data operation");
4095 if (getAsyncOperand() && getAsync())
4096 return emitError(
"async attribute cannot appear with asyncOperand");
4100 if (!getWaitOperands().empty() && getWait())
4101 return emitError(
"wait attribute cannot appear with waitOperands");
4103 if (getWaitDevnum() && getWaitOperands().empty())
4104 return emitError(
"wait_devnum cannot appear without waitOperands");
4109unsigned ExitDataOp::getNumDataOperands() {
4110 return getDataClauseOperands().size();
4113Value ExitDataOp::getDataOperand(
unsigned i) {
4114 unsigned numOptional = getIfCond() ? 1 : 0;
4115 numOptional += getAsyncOperand() ? 1 : 0;
4116 numOptional += getWaitDevnum() ? 1 : 0;
4117 return getOperand(getWaitOperands().size() + numOptional + i);
4122 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4125void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4127 assert(effectiveDeviceTypes.empty());
4128 assert(!getAsyncAttr());
4129 assert(!getAsyncOperand());
4131 setAsyncAttr(mlir::UnitAttr::get(context));
4134void ExitDataOp::addAsyncOperand(
4137 assert(effectiveDeviceTypes.empty());
4138 assert(!getAsyncAttr());
4139 assert(!getAsyncOperand());
4141 getAsyncOperandMutable().append(newValue);
4146 assert(effectiveDeviceTypes.empty());
4147 assert(!getWaitAttr());
4148 assert(getWaitOperands().empty());
4149 assert(!getWaitDevnum());
4151 setWaitAttr(mlir::UnitAttr::get(context));
4154void ExitDataOp::addWaitOperands(
4157 assert(effectiveDeviceTypes.empty());
4158 assert(!getWaitAttr());
4159 assert(getWaitOperands().empty());
4160 assert(!getWaitDevnum());
4165 getWaitDevnumMutable().append(newValues.front());
4166 newValues = newValues.drop_front();
4169 getWaitOperandsMutable().append(newValues);
4176LogicalResult acc::EnterDataOp::verify() {
4180 if (getDataClauseOperands().empty())
4181 return emitError(
"at least one operand must be present in dataOperands on "
4182 "the enter data operation");
4186 if (getAsyncOperand() && getAsync())
4187 return emitError(
"async attribute cannot appear with asyncOperand");
4191 if (!getWaitOperands().empty() && getWait())
4192 return emitError(
"wait attribute cannot appear with waitOperands");
4194 if (getWaitDevnum() && getWaitOperands().empty())
4195 return emitError(
"wait_devnum cannot appear without waitOperands");
4197 for (
mlir::Value operand : getDataClauseOperands())
4198 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4199 operand.getDefiningOp()))
4200 return emitError(
"expect data entry operation as defining op");
4205unsigned EnterDataOp::getNumDataOperands() {
4206 return getDataClauseOperands().size();
4209Value EnterDataOp::getDataOperand(
unsigned i) {
4210 unsigned numOptional = getIfCond() ? 1 : 0;
4211 numOptional += getAsyncOperand() ? 1 : 0;
4212 numOptional += getWaitDevnum() ? 1 : 0;
4213 return getOperand(getWaitOperands().size() + numOptional + i);
4218 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4221void EnterDataOp::addAsyncOnly(
4223 assert(effectiveDeviceTypes.empty());
4224 assert(!getAsyncAttr());
4225 assert(!getAsyncOperand());
4227 setAsyncAttr(mlir::UnitAttr::get(context));
4230void EnterDataOp::addAsyncOperand(
4233 assert(effectiveDeviceTypes.empty());
4234 assert(!getAsyncAttr());
4235 assert(!getAsyncOperand());
4237 getAsyncOperandMutable().append(newValue);
4240void EnterDataOp::addWaitOnly(
MLIRContext *context,
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getWaitAttr());
4244 assert(getWaitOperands().empty());
4245 assert(!getWaitDevnum());
4247 setWaitAttr(mlir::UnitAttr::get(context));
4250void EnterDataOp::addWaitOperands(
4253 assert(effectiveDeviceTypes.empty());
4254 assert(!getWaitAttr());
4255 assert(getWaitOperands().empty());
4256 assert(!getWaitDevnum());
4261 getWaitDevnumMutable().append(newValues.front());
4262 newValues = newValues.drop_front();
4265 getWaitOperandsMutable().append(newValues);
4272LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4278LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4284LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4291 if (
Value writeVal = op.getWriteOpVal()) {
4300LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4302LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4308AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4309 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4311 return dyn_cast<AtomicReadOp>(getSecondOp());
4314AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4315 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4317 return dyn_cast<AtomicWriteOp>(getSecondOp());
4320AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4321 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4323 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4326LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4332template <
typename Op>
4335 bool requireAtLeastOneOperand =
true) {
4336 if (operands.empty() && requireAtLeastOneOperand)
4339 "at least one operand must appear on the declare operation");
4342 if (isa<BlockArgument>(operand) ||
4343 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4344 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4345 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4346 operand.getDefiningOp()))
4348 "expect valid declare data entry operation or acc.getdeviceptr "
4352 assert(var &&
"declare operands can only be data entry operations which "
4355 std::optional<mlir::acc::DataClause> dataClauseOptional{
4357 assert(dataClauseOptional.has_value() &&
4358 "declare operands can only be data entry operations which must have "
4360 (
void)dataClauseOptional;
4366LogicalResult acc::DeclareEnterOp::verify() {
4374LogicalResult acc::DeclareExitOp::verify() {
4385LogicalResult acc::DeclareOp::verify() {
4394 acc::DeviceType dtype) {
4395 unsigned parallelism = 0;
4396 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4397 parallelism += op.hasWorker(dtype) ? 1 : 0;
4398 parallelism += op.hasVector(dtype) ? 1 : 0;
4399 parallelism += op.hasSeq(dtype) ? 1 : 0;
4403LogicalResult acc::RoutineOp::verify() {
4404 unsigned baseParallelism =
4407 if (baseParallelism > 1)
4408 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4409 "be present at the same time";
4411 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4413 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4414 if (dtype == acc::DeviceType::None)
4418 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4419 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4420 "be present at the same time for device_type `"
4421 << acc::stringifyDeviceType(dtype) <<
"`";
4428 mlir::ArrayAttr &bindIdName,
4429 mlir::ArrayAttr &bindStrName,
4430 mlir::ArrayAttr &deviceIdTypes,
4431 mlir::ArrayAttr &deviceStrTypes) {
4438 mlir::Attribute newAttr;
4439 bool isSymbolRefAttr;
4440 auto parseResult = parser.parseAttribute(newAttr);
4441 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4442 bindIdNameAttrs.push_back(symbolRefAttr);
4443 isSymbolRefAttr = true;
4444 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4445 bindStrNameAttrs.push_back(stringAttr);
4446 isSymbolRefAttr =
false;
4451 if (isSymbolRefAttr) {
4452 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4453 parser.getContext(), mlir::acc::DeviceType::None));
4455 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4456 parser.getContext(), mlir::acc::DeviceType::None));
4459 if (isSymbolRefAttr) {
4460 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4461 parser.parseRSquare())
4464 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4465 parser.parseRSquare())
4473 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4474 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4475 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4476 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4482 std::optional<mlir::ArrayAttr> bindIdName,
4483 std::optional<mlir::ArrayAttr> bindStrName,
4484 std::optional<mlir::ArrayAttr> deviceIdTypes,
4485 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4492 allBindNames.append(bindIdName->begin(), bindIdName->end());
4493 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4498 allBindNames.append(bindStrName->begin(), bindStrName->end());
4499 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4503 if (!allBindNames.empty())
4504 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4505 [&](
const auto &pair) {
4506 p << std::get<0>(pair);
4512 mlir::ArrayAttr &gang,
4513 mlir::ArrayAttr &gangDim,
4514 mlir::ArrayAttr &gangDimDeviceTypes) {
4517 gangDimDeviceTypeAttrs;
4518 bool needCommaBeforeOperands =
false;
4522 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4523 parser.
getContext(), mlir::acc::DeviceType::None));
4524 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4531 if (parser.parseAttribute(gangAttrs.emplace_back()))
4538 needCommaBeforeOperands =
true;
4541 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4545 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4546 parser.parseColon() ||
4547 parser.parseAttribute(gangDimAttrs.emplace_back()))
4549 if (succeeded(parser.parseOptionalLSquare())) {
4550 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4551 parser.parseRSquare())
4554 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4555 parser.getContext(), mlir::acc::DeviceType::None));
4561 if (
failed(parser.parseRParen()))
4564 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4565 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4566 gangDimDeviceTypes =
4567 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4573 std::optional<mlir::ArrayAttr> gang,
4574 std::optional<mlir::ArrayAttr> gangDim,
4575 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4578 gang->size() == 1) {
4579 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4580 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4592 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4593 [&](
const auto &pair) {
4594 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4595 p << std::get<0>(pair);
4603 mlir::ArrayAttr &deviceTypes) {
4607 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4608 parser.
getContext(), mlir::acc::DeviceType::None));
4609 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4616 if (parser.parseAttribute(attributes.emplace_back()))
4624 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4630 std::optional<mlir::ArrayAttr> deviceTypes) {
4633 auto deviceTypeAttr =
4634 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4635 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4644 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4650bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4652bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4656bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4658bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4662bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4664bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4668std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4669RoutineOp::getBindNameValue() {
4670 return getBindNameValue(mlir::acc::DeviceType::None);
4673std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4674RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4677 return std::nullopt;
4680 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4681 auto attr = (*getBindIdName())[*pos];
4682 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4683 assert(symbolRefAttr &&
"expected SymbolRef");
4684 return symbolRefAttr;
4687 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4688 auto attr = (*getBindStrName())[*pos];
4689 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4690 assert(stringAttr &&
"expected String");
4694 return std::nullopt;
4697bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4699bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4703std::optional<int64_t> RoutineOp::getGangDimValue() {
4704 return getGangDimValue(mlir::acc::DeviceType::None);
4707std::optional<int64_t>
4708RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4710 return std::nullopt;
4711 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4712 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4713 return intAttr.getInt();
4715 return std::nullopt;
4720 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4721 effectiveDeviceTypes));
4726 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4727 effectiveDeviceTypes));
4732 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4733 effectiveDeviceTypes));
4738 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4739 effectiveDeviceTypes));
4748 if (getGangDimAttr())
4749 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4750 if (getGangDimDeviceTypeAttr())
4751 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4753 assert(dimValues.size() == deviceTypes.size());
4755 if (effectiveDeviceTypes.empty()) {
4756 dimValues.push_back(
4757 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4758 deviceTypes.push_back(
4759 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4761 for (DeviceType dt : effectiveDeviceTypes) {
4762 dimValues.push_back(
4763 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4764 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4767 assert(dimValues.size() == deviceTypes.size());
4769 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4770 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4773void RoutineOp::addBindStrName(
MLIRContext *context,
4775 mlir::StringAttr val) {
4776 unsigned before = getBindStrNameDeviceTypeAttr()
4777 ? getBindStrNameDeviceTypeAttr().size()
4780 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4781 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4782 unsigned after = getBindStrNameDeviceTypeAttr().size();
4785 if (getBindStrNameAttr())
4786 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4787 for (
unsigned i = 0; i < after - before; ++i)
4788 vals.push_back(val);
4790 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4793void RoutineOp::addBindIDName(
MLIRContext *context,
4795 mlir::SymbolRefAttr val) {
4797 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4799 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4800 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4801 unsigned after = getBindIdNameDeviceTypeAttr().size();
4804 if (getBindIdNameAttr())
4805 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4806 for (
unsigned i = 0; i < after - before; ++i)
4807 vals.push_back(val);
4809 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4816LogicalResult acc::InitOp::verify() {
4817 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4818 return emitOpError(
"cannot be nested in a compute operation");
4822void acc::InitOp::addDeviceType(
MLIRContext *context,
4823 mlir::acc::DeviceType deviceType) {
4825 if (getDeviceTypesAttr())
4826 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4828 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4829 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4836LogicalResult acc::ShutdownOp::verify() {
4837 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4838 return emitOpError(
"cannot be nested in a compute operation");
4842void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4843 mlir::acc::DeviceType deviceType) {
4845 if (getDeviceTypesAttr())
4846 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4848 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4849 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4856LogicalResult acc::SetOp::verify() {
4857 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4858 return emitOpError(
"cannot be nested in a compute operation");
4859 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4860 return emitOpError(
"at least one default_async, device_num, or device_type "
4861 "operand must appear");
4869LogicalResult acc::UpdateOp::verify() {
4871 if (getDataClauseOperands().empty())
4872 return emitError(
"at least one value must be present in dataOperands");
4875 getAsyncOperandsDeviceTypeAttr(),
4880 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4881 getWaitOperandsDeviceTypeAttr(),
"wait")))
4887 for (
mlir::Value operand : getDataClauseOperands())
4888 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4889 operand.getDefiningOp()))
4890 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4896unsigned UpdateOp::getNumDataOperands() {
4897 return getDataClauseOperands().size();
4900Value UpdateOp::getDataOperand(
unsigned i) {
4902 numOptional += getIfCond() ? 1 : 0;
4903 return getOperand(getWaitOperands().size() + numOptional + i);
4908 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4911bool UpdateOp::hasAsyncOnly() {
4912 return hasAsyncOnly(mlir::acc::DeviceType::None);
4915bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4920 return getAsyncValue(mlir::acc::DeviceType::None);
4923mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4933bool UpdateOp::hasWaitOnly() {
4934 return hasWaitOnly(mlir::acc::DeviceType::None);
4937bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4942 return getWaitValues(mlir::acc::DeviceType::None);
4946UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4948 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4949 getHasWaitDevnum(), deviceType);
4953 return getWaitDevnum(mlir::acc::DeviceType::None);
4956mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4958 getWaitOperandsSegments(), getHasWaitDevnum(),
4964 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4965 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4968void UpdateOp::addAsyncOperand(
4971 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4972 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4973 getAsyncOperandsMutable()));
4978 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4979 effectiveDeviceTypes));
4982void UpdateOp::addWaitOperands(
4987 if (getWaitOperandsSegments())
4988 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4990 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4991 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4992 getWaitOperandsMutable(), segments));
4993 setWaitOperandsSegments(segments);
4996 if (getHasWaitDevnumAttr())
4997 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5000 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5002 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5009LogicalResult acc::WaitOp::verify() {
5012 if (getAsyncOperand() && getAsync())
5013 return emitError(
"async attribute cannot appear with asyncOperand");
5015 if (getWaitDevnum() && getWaitOperands().empty())
5016 return emitError(
"wait_devnum cannot appear without waitOperands");
5021#define GET_OP_CLASSES
5022#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5024#define GET_ATTRDEF_CLASSES
5025#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5027#define GET_TYPEDEF_CLASSES
5028#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5039 .Case<ACC_DATA_ENTRY_OPS>(
5040 [&](
auto entry) {
return entry.getVarPtr(); })
5041 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5042 [&](
auto exit) {
return exit.getVarPtr(); })
5060 [&](
auto entry) {
return entry.getVarType(); })
5061 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5062 [&](
auto exit) {
return exit.getVarType(); })
5072 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5073 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5083 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5092 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5102 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5104 dataClause.getBounds().begin(), dataClause.getBounds().end());
5116 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5118 dataClause.getAsyncOperands().begin(),
5119 dataClause.getAsyncOperands().end());
5130 return dataClause.getAsyncOperandsDeviceTypeAttr();
5138 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5145 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5152std::optional<mlir::acc::DataClause>
5157 .Case<ACC_DATA_ENTRY_OPS>(
5158 [&](
auto entry) {
return entry.getDataClause(); })
5166 [&](
auto entry) {
return entry.getImplicit(); })
5175 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5177 return dataOperands;
5185 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5187 return dataOperands;
5194 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.