25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249 bool isDeviceData(Type pointer, Value var)
const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
256struct LLVMPointerPointerLikeModel
257 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
261 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType)
const {
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
278struct MemrefAddressOfGlobalModel
279 :
public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op)
const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
287struct MemrefGlobalVariableModel
288 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
290 bool isConstant(Operation *op)
const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
295 Region *getInitRegion(Operation *op)
const {
300 bool isDeviceData(Operation *op)
const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
307struct GPULaunchOffloadRegionModel
308 :
public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
311 return cast<gpu::LaunchOp>(op).getBody();
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
333 return mlir::ArrayAttr::get(context, deviceTypes);
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
351 if (newDeviceTypes.empty()) {
352 argCollection.
append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.
append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
364 return mlir::ArrayAttr::get(context, deviceTypes);
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
383void OpenACCDialect::initialize() {
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
448void ParallelOp::getSuccessorRegions(
478void HostDataOp::getSuccessorRegions(
493 if (getUnstructured()) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
530 mlir::acc::DeviceType deviceType) {
534 for (
auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
544 std::optional<mlir::ArrayAttr> deviceTypes) {
549 llvm::interleaveComma(*deviceTypes, p,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (
auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
570 mlir::acc::DeviceType deviceType) {
572 return range.take_front(0);
573 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (
unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
579 return range.take_front(0);
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
590 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
608 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1);
618template <
typename Op>
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
622 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
627 op.hasAsyncOnly(dtype))
629 "asyncOnly attribute cannot appear with asyncOperand");
634 op.hasWaitOnly(dtype))
635 return op.
emitError(
"wait attribute cannot appear with waitOperands");
640template <
typename Op>
643 return op.
emitError(
"must have var operand");
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.
emitError(
"var must be mappable or pointer-like");
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.
emitError(
"varType must capture the element type of var");
658template <
typename Op>
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.
emitError(
"input and output types must match");
666template <
typename Op>
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.
emitError(
"no data clause modifiers are allowed");
673template <
typename Op>
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
684template <
typename OpT,
typename RecipeOpT>
685static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
690 !std::is_same_v<OpT, acc::ReductionOp>)
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
695 return op->emitOpError() <<
"recipe expected for " << operandName;
700 return op->emitOpError()
701 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
702 << operandName <<
" declaration";
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
744 if (failed(parser.
parseType(accVarType)))
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.
parseType(varPtrType)))
778 varTypeAttr = mlir::TypeAttr::get(varType);
783 if (
auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
784 Type elementType = ptrTy.getElementType();
787 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
789 varTypeAttr = mlir::TypeAttr::get(varPtrType);
797 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
805 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
806 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
810 if (!typeToCheckAgainst)
811 typeToCheckAgainst = varPtrType;
812 if (typeToCheckAgainst != varType) {
820 mlir::SymbolRefAttr &recipeAttr) {
827 mlir::SymbolRefAttr recipeAttr) {
834LogicalResult acc::DataBoundsOp::verify() {
835 auto extent = getExtent();
836 auto upperbound = getUpperbound();
837 if (!extent && !upperbound)
838 return emitError(
"expected extent or upperbound.");
845LogicalResult acc::PrivateOp::verify() {
848 "data clause associated with private operation must match its intent");
862LogicalResult acc::FirstprivateOp::verify() {
864 return emitError(
"data clause associated with firstprivate operation must "
871 *
this,
"firstprivate")))
879LogicalResult acc::ReductionOp::verify() {
881 return emitError(
"data clause associated with reduction operation must "
888 *
this,
"reduction")))
896LogicalResult acc::DevicePtrOp::verify() {
898 return emitError(
"data clause associated with deviceptr operation must "
912LogicalResult acc::PresentOp::verify() {
915 "data clause associated with present operation must match its intent");
928LogicalResult acc::CopyinOp::verify() {
930 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
935 "data clause associated with copyin operation must match its intent"
936 " or specify original clause this operation was decomposed from");
942 acc::DataClauseModifier::always |
943 acc::DataClauseModifier::capture)))
948bool acc::CopyinOp::isCopyinReadonly() {
949 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
950 acc::bitEnumContainsAny(getModifiers(),
951 acc::DataClauseModifier::readonly);
957LogicalResult acc::CreateOp::verify() {
964 "data clause associated with create operation must match its intent"
965 " or specify original clause this operation was decomposed from");
973 acc::DataClauseModifier::always |
974 acc::DataClauseModifier::capture)))
979bool acc::CreateOp::isCreateZero() {
981 return getDataClause() == acc::DataClause::acc_create_zero ||
983 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
989LogicalResult acc::NoCreateOp::verify() {
991 return emitError(
"data clause associated with no_create operation must "
1005LogicalResult acc::AttachOp::verify() {
1008 "data clause associated with attach operation must match its intent");
1022LogicalResult acc::DeclareDeviceResidentOp::verify() {
1023 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1024 return emitError(
"data clause associated with device_resident operation "
1025 "must match its intent");
1039LogicalResult acc::DeclareLinkOp::verify() {
1042 "data clause associated with link operation must match its intent");
1055LogicalResult acc::CopyoutOp::verify() {
1062 "data clause associated with copyout operation must match its intent"
1063 " or specify original clause this operation was decomposed from");
1065 return emitError(
"must have both host and device pointers");
1071 acc::DataClauseModifier::always |
1072 acc::DataClauseModifier::capture)))
1077bool acc::CopyoutOp::isCopyoutZero() {
1078 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1079 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1085LogicalResult acc::DeleteOp::verify() {
1094 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1097 "data clause associated with delete operation must match its intent"
1098 " or specify original clause this operation was decomposed from");
1100 return emitError(
"must have device pointer");
1104 acc::DataClauseModifier::readonly |
1105 acc::DataClauseModifier::always |
1106 acc::DataClauseModifier::capture)))
1114LogicalResult acc::DetachOp::verify() {
1119 "data clause associated with detach operation must match its intent"
1120 " or specify original clause this operation was decomposed from");
1122 return emitError(
"must have device pointer");
1131LogicalResult acc::UpdateHostOp::verify() {
1136 "data clause associated with host operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1139 return emitError(
"must have both host and device pointers");
1152LogicalResult acc::UpdateDeviceOp::verify() {
1156 "data clause associated with device operation must match its intent"
1157 " or specify original clause this operation was decomposed from");
1170LogicalResult acc::UseDeviceOp::verify() {
1174 "data clause associated with use_device operation must match its intent"
1175 " or specify original clause this operation was decomposed from");
1188LogicalResult acc::CacheOp::verify() {
1193 "data clause associated with cache operation must match its intent"
1194 " or specify original clause this operation was decomposed from");
1204bool acc::CacheOp::isCacheReadonly() {
1205 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1206 acc::bitEnumContainsAny(getModifiers(),
1207 acc::DataClauseModifier::readonly);
1223template <
typename EffectTy>
1228 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1229 effects.emplace_back(EffectTy::get(), &operand[i]);
1233template <
typename EffectTy>
1238 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1242void acc::PrivateOp::getEffects(
1256void acc::FirstprivateOp::getEffects(
1270void acc::ReductionOp::getEffects(
1284void acc::DevicePtrOp::getEffects(
1293void acc::PresentOp::getEffects(
1304void acc::CopyinOp::getEffects(
1317void acc::CreateOp::getEffects(
1330void acc::NoCreateOp::getEffects(
1341void acc::AttachOp::getEffects(
1354void acc::GetDevicePtrOp::getEffects(
1363void acc::UpdateDeviceOp::getEffects(
1373void acc::UseDeviceOp::getEffects(
1382void acc::DeclareDeviceResidentOp::getEffects(
1393void acc::DeclareLinkOp::getEffects(
1404void acc::CacheOp::getEffects(
1409void acc::CopyoutOp::getEffects(
1422void acc::DeleteOp::getEffects(
1434void acc::DetachOp::getEffects(
1446void acc::UpdateHostOp::getEffects(
1458template <
typename StructureOp>
1460 unsigned nRegions = 1) {
1463 for (
unsigned i = 0; i < nRegions; ++i)
1466 for (
Region *region : regions)
1477template <
typename OpTy>
1479 using OpRewritePattern<OpTy>::OpRewritePattern;
1481 LogicalResult matchAndRewrite(OpTy op,
1482 PatternRewriter &rewriter)
const override {
1484 Value ifCond = op.getIfCond();
1488 IntegerAttr constAttr;
1491 if (constAttr.getInt())
1492 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1504 assert(region.
hasOneBlock() &&
"expected single-block region");
1516template <
typename OpTy>
1517struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1518 using OpRewritePattern<OpTy>::OpRewritePattern;
1520 LogicalResult matchAndRewrite(OpTy op,
1521 PatternRewriter &rewriter)
const override {
1523 Value ifCond = op.getIfCond();
1527 IntegerAttr constAttr;
1530 if (constAttr.getInt())
1531 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1559 for (
Value bound : bounds) {
1560 argTypes.push_back(bound.getType());
1561 argLocs.push_back(loc);
1568 Value privatizedValue;
1574 if (isa<MappableType>(varType)) {
1575 auto mappableTy = cast<MappableType>(varType);
1576 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1577 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1578 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1579 privatizedValue = mappableTy.generatePrivateInit(
1580 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1581 if (!privatizedValue)
1584 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1585 auto pointerLikeTy = cast<PointerLikeType>(varType);
1587 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1588 blockArgVar, needsFree);
1589 if (!privatizedValue)
1594 acc::YieldOp::create(builder, loc, privatizedValue);
1609 for (
Value bound : bounds) {
1610 copyArgTypes.push_back(bound.getType());
1611 copyArgLocs.push_back(loc);
1618 bool isMappable = isa<MappableType>(varType);
1619 bool isPointerLike = isa<PointerLikeType>(varType);
1622 if (isMappable && !isPointerLike)
1626 if (isPointerLike) {
1627 auto pointerLikeTy = cast<PointerLikeType>(varType);
1632 if (!pointerLikeTy.genCopy(
1639 acc::TerminatorOp::create(builder, loc);
1656 for (
Value bound : bounds) {
1657 destroyArgTypes.push_back(bound.getType());
1658 destroyArgLocs.push_back(loc);
1662 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1666 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1667 if (isa<MappableType>(varType)) {
1668 auto mappableTy = cast<MappableType>(varType);
1669 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1673 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1674 auto pointerLikeTy = cast<PointerLikeType>(varType);
1675 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1679 acc::TerminatorOp::create(builder, loc);
1690 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1692 if (optional && region.
empty())
1696 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1700 return op->
emitOpError() <<
"expects " << regionName
1703 << regionType <<
" type";
1706 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1707 if (yieldOp.getOperands().size() != 1 ||
1708 yieldOp.getOperands().getTypes()[0] != type)
1709 return op->
emitOpError() <<
"expects " << regionName
1711 "yield a value of the "
1712 << regionType <<
" type";
1718LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1720 "privatization",
"init",
getType(),
1724 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1730std::optional<PrivateRecipeOp>
1732 StringRef recipeName,
Value hostVar,
1737 bool isMappable = isa<MappableType>(varType);
1738 bool isPointerLike = isa<PointerLikeType>(varType);
1741 if (!isMappable && !isPointerLike)
1742 return std::nullopt;
1747 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1750 bool needsFree =
false;
1752 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1753 varName, bounds, needsFree, varInfo))) {
1755 return std::nullopt;
1762 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1763 Value allocRes = yieldOp.getOperand(0);
1765 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1766 varType, allocRes, bounds, varInfo))) {
1768 return std::nullopt;
1775std::optional<PrivateRecipeOp>
1777 StringRef recipeName,
1778 FirstprivateRecipeOp firstprivRecipe) {
1781 auto varType = firstprivRecipe.getType();
1782 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1786 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1789 if (!firstprivRecipe.getDestroyRegion().empty()) {
1791 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1801LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1803 "privatization",
"init",
getType(),
1807 if (getCopyRegion().empty())
1808 return emitOpError() <<
"expects non-empty copy region";
1813 return emitOpError() <<
"expects copy region with two arguments of the "
1814 "privatization type";
1816 if (getDestroyRegion().empty())
1820 "privatization",
"destroy",
1827std::optional<FirstprivateRecipeOp>
1829 StringRef recipeName,
Value hostVar,
1834 bool isMappable = isa<MappableType>(varType);
1835 bool isPointerLike = isa<PointerLikeType>(varType);
1838 if (!isMappable && !isPointerLike)
1839 return std::nullopt;
1844 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1847 bool needsFree =
false;
1849 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1850 varName, bounds, needsFree, varInfo))) {
1852 return std::nullopt;
1856 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1859 return std::nullopt;
1866 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1867 Value allocRes = yieldOp.getOperand(0);
1869 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1870 varType, allocRes, bounds, varInfo))) {
1872 return std::nullopt;
1883LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1889 if (getCombinerRegion().empty())
1890 return emitOpError() <<
"expects non-empty combiner region";
1892 Block &reductionBlock = getCombinerRegion().
front();
1896 return emitOpError() <<
"expects combiner region with the first two "
1897 <<
"arguments of the reduction type";
1899 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1900 if (yieldOp.getOperands().size() != 1 ||
1901 yieldOp.getOperands().getTypes()[0] !=
getType())
1902 return emitOpError() <<
"expects combiner region to yield a value "
1903 "of the reduction type";
1914template <
typename Op>
1918 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1919 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1920 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1921 operand.getDefiningOp()))
1923 "expect data entry/exit operation or acc.getdeviceptr "
1928template <
typename OpT,
typename RecipeOpT>
1931 llvm::StringRef operandName) {
1934 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1936 <<
"expected " << operandName <<
" as defining op";
1937 if (!set.insert(operand).second)
1939 << operandName <<
" operand appears more than once";
1944unsigned ParallelOp::getNumDataOperands() {
1945 return getReductionOperands().size() + getPrivateOperands().size() +
1946 getFirstprivateOperands().size() + getDataClauseOperands().size();
1949Value ParallelOp::getDataOperand(
unsigned i) {
1951 numOptional += getNumGangs().size();
1952 numOptional += getNumWorkers().size();
1953 numOptional += getVectorLength().size();
1954 numOptional += getIfCond() ? 1 : 0;
1955 numOptional += getSelfCond() ? 1 : 0;
1956 return getOperand(getWaitOperands().size() + numOptional + i);
1959template <
typename Op>
1962 llvm::StringRef keyword) {
1963 if (!operands.empty() &&
1964 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
1965 return op.
emitOpError() << keyword <<
" operands count must match "
1966 << keyword <<
" device_type count";
1970template <
typename Op>
1973 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1974 std::size_t numOperandsInSegments = 0;
1975 std::size_t nbOfSegments = 0;
1978 for (
auto segCount : segments.
asArrayRef()) {
1979 if (maxInSegment != 0 && segCount > maxInSegment)
1980 return op.
emitOpError() << keyword <<
" expects a maximum of "
1981 << maxInSegment <<
" values per segment";
1982 numOperandsInSegments += segCount;
1987 if ((numOperandsInSegments != operands.size()) ||
1988 (!deviceTypes && !operands.empty()))
1990 << keyword <<
" operand count does not match count in segments";
1991 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1993 << keyword <<
" segment count does not match device_type count";
1997LogicalResult acc::ParallelOp::verify() {
1999 mlir::acc::PrivateRecipeOp>(
2000 *
this, getPrivateOperands(),
"private")))
2003 mlir::acc::FirstprivateRecipeOp>(
2004 *
this, getFirstprivateOperands(),
"firstprivate")))
2007 mlir::acc::ReductionRecipeOp>(
2008 *
this, getReductionOperands(),
"reduction")))
2012 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2013 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2017 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2018 getWaitOperandsDeviceTypeAttr(),
"wait")))
2022 getNumWorkersDeviceTypeAttr(),
2027 getVectorLengthDeviceTypeAttr(),
2032 getAsyncOperandsDeviceTypeAttr(),
2045 mlir::acc::DeviceType deviceType) {
2048 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2053bool acc::ParallelOp::hasAsyncOnly() {
2054 return hasAsyncOnly(mlir::acc::DeviceType::None);
2057bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2062 return getAsyncValue(mlir::acc::DeviceType::None);
2065mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2070mlir::Value acc::ParallelOp::getNumWorkersValue() {
2071 return getNumWorkersValue(mlir::acc::DeviceType::None);
2075acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2080mlir::Value acc::ParallelOp::getVectorLengthValue() {
2081 return getVectorLengthValue(mlir::acc::DeviceType::None);
2085acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2087 getVectorLength(), deviceType);
2091 return getNumGangsValues(mlir::acc::DeviceType::None);
2095ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2097 getNumGangsSegments(), deviceType);
2100bool acc::ParallelOp::hasWaitOnly() {
2101 return hasWaitOnly(mlir::acc::DeviceType::None);
2104bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2109 return getWaitValues(mlir::acc::DeviceType::None);
2113ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2115 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2116 getHasWaitDevnum(), deviceType);
2120 return getWaitDevnum(mlir::acc::DeviceType::None);
2123mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2125 getWaitOperandsSegments(), getHasWaitDevnum(),
2140 odsBuilder, odsState, asyncOperands,
nullptr,
2141 nullptr, waitOperands,
nullptr,
2143 nullptr, numGangs,
nullptr,
2144 nullptr, numWorkers,
2145 nullptr, vectorLength,
2146 nullptr, ifCond, selfCond,
2147 nullptr, reductionOperands, gangPrivateOperands,
2148 gangFirstPrivateOperands, dataClauseOperands,
2152void acc::ParallelOp::addNumWorkersOperand(
2155 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2156 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2157 getNumWorkersMutable()));
2159void acc::ParallelOp::addVectorLengthOperand(
2162 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2163 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2164 getVectorLengthMutable()));
2167void acc::ParallelOp::addAsyncOnly(
2169 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2170 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2173void acc::ParallelOp::addAsyncOperand(
2176 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2177 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2178 getAsyncOperandsMutable()));
2181void acc::ParallelOp::addNumGangsOperands(
2185 if (getNumGangsSegments())
2186 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2188 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2189 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2190 getNumGangsMutable(), segments));
2192 setNumGangsSegments(segments);
2194void acc::ParallelOp::addWaitOnly(
2196 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2197 effectiveDeviceTypes));
2199void acc::ParallelOp::addWaitOperands(
2204 if (getWaitOperandsSegments())
2205 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2207 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2208 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2209 getWaitOperandsMutable(), segments));
2210 setWaitOperandsSegments(segments);
2213 if (getHasWaitDevnumAttr())
2214 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2217 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2219 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2222void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2223 mlir::acc::PrivateOp op,
2224 mlir::acc::PrivateRecipeOp recipe) {
2225 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2226 getPrivateOperandsMutable().append(op.getResult());
2229void acc::ParallelOp::addFirstPrivatization(
2230 MLIRContext *context, mlir::acc::FirstprivateOp op,
2231 mlir::acc::FirstprivateRecipeOp recipe) {
2232 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2233 getFirstprivateOperandsMutable().append(op.getResult());
2236void acc::ParallelOp::addReduction(
MLIRContext *context,
2237 mlir::acc::ReductionOp op,
2238 mlir::acc::ReductionRecipeOp recipe) {
2239 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2240 getReductionOperandsMutable().append(op.getResult());
2255 int32_t crtOperandsSize = operands.size();
2258 if (parser.parseOperand(operands.emplace_back()) ||
2259 parser.parseColonType(types.emplace_back()))
2264 seg.push_back(operands.size() - crtOperandsSize);
2274 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2275 parser.
getContext(), mlir::acc::DeviceType::None));
2281 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2288 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2289 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2290 p <<
" [" << attr <<
"]";
2295 std::optional<mlir::ArrayAttr> deviceTypes,
2296 std::optional<mlir::DenseI32ArrayAttr> segments) {
2298 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2300 llvm::interleaveComma(
2301 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2302 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2322 int32_t crtOperandsSize = operands.size();
2326 if (parser.parseOperand(operands.emplace_back()) ||
2327 parser.parseColonType(types.emplace_back()))
2333 seg.push_back(operands.size() - crtOperandsSize);
2343 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2344 parser.
getContext(), mlir::acc::DeviceType::None));
2350 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2359 std::optional<mlir::DenseI32ArrayAttr> segments) {
2361 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2363 llvm::interleaveComma(
2364 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2365 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2378 mlir::ArrayAttr &keywordOnly) {
2382 bool needCommaBeforeOperands =
false;
2386 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2387 parser.
getContext(), mlir::acc::DeviceType::None));
2388 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2395 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2402 needCommaBeforeOperands =
true;
2405 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2412 int32_t crtOperandsSize = operands.size();
2424 if (parser.parseOperand(operands.emplace_back()) ||
2425 parser.parseColonType(types.emplace_back()))
2431 seg.push_back(operands.size() - crtOperandsSize);
2441 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2442 parser.
getContext(), mlir::acc::DeviceType::None));
2449 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2450 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2452 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2460 if (attrs->size() != 1)
2462 if (
auto deviceTypeAttr =
2463 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2464 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2470 std::optional<mlir::ArrayAttr> deviceTypes,
2471 std::optional<mlir::DenseI32ArrayAttr> segments,
2472 std::optional<mlir::ArrayAttr> hasDevNum,
2473 std::optional<mlir::ArrayAttr> keywordOnly) {
2486 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2488 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2489 if (boolAttr && boolAttr.getValue())
2491 llvm::interleaveComma(
2492 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2493 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2510 if (parser.parseOperand(operands.emplace_back()) ||
2511 parser.parseColonType(types.emplace_back()))
2513 if (succeeded(parser.parseOptionalLSquare())) {
2514 if (parser.parseAttribute(attributes.emplace_back()) ||
2515 parser.parseRSquare())
2518 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2519 parser.getContext(), mlir::acc::DeviceType::None));
2526 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2533 std::optional<mlir::ArrayAttr> deviceTypes) {
2536 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2537 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2546 mlir::ArrayAttr &keywordOnlyDeviceType) {
2549 bool needCommaBeforeOperands =
false;
2553 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2554 parser.
getContext(), mlir::acc::DeviceType::None));
2555 keywordOnlyDeviceType =
2556 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2564 if (parser.parseAttribute(
2565 keywordOnlyDeviceTypeAttributes.emplace_back()))
2572 needCommaBeforeOperands =
true;
2575 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2580 if (parser.parseOperand(operands.emplace_back()) ||
2581 parser.parseColonType(types.emplace_back()))
2583 if (succeeded(parser.parseOptionalLSquare())) {
2584 if (parser.parseAttribute(attributes.emplace_back()) ||
2585 parser.parseRSquare())
2588 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2589 parser.getContext(), mlir::acc::DeviceType::None));
2595 if (
failed(parser.parseRParen()))
2600 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2607 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2609 if (operands.begin() == operands.end() &&
2625 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2626 mlir::Type &operandType, mlir::UnitAttr &attr) {
2629 attr = mlir::UnitAttr::get(parser.
getContext());
2639 if (failed(parser.
parseType(operandType)))
2649 std::optional<mlir::Value> operand,
2651 mlir::UnitAttr attr) {
2668 attr = mlir::UnitAttr::get(parser.
getContext());
2673 if (parser.parseOperand(operands.emplace_back()))
2681 if (parser.parseType(types.emplace_back()))
2696 mlir::UnitAttr attr) {
2701 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2703 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2709 mlir::acc::CombinedConstructsTypeAttr &attr) {
2711 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2712 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2714 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2715 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2717 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2718 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2721 "expected compute construct name");
2729 mlir::acc::CombinedConstructsTypeAttr attr) {
2731 switch (attr.getValue()) {
2732 case mlir::acc::CombinedConstructsType::KernelsLoop:
2735 case mlir::acc::CombinedConstructsType::ParallelLoop:
2738 case mlir::acc::CombinedConstructsType::SerialLoop:
2749unsigned SerialOp::getNumDataOperands() {
2750 return getReductionOperands().size() + getPrivateOperands().size() +
2751 getFirstprivateOperands().size() + getDataClauseOperands().size();
2754Value SerialOp::getDataOperand(
unsigned i) {
2756 numOptional += getIfCond() ? 1 : 0;
2757 numOptional += getSelfCond() ? 1 : 0;
2758 return getOperand(getWaitOperands().size() + numOptional + i);
2761bool acc::SerialOp::hasAsyncOnly() {
2762 return hasAsyncOnly(mlir::acc::DeviceType::None);
2765bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2770 return getAsyncValue(mlir::acc::DeviceType::None);
2773mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2778bool acc::SerialOp::hasWaitOnly() {
2779 return hasWaitOnly(mlir::acc::DeviceType::None);
2782bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2787 return getWaitValues(mlir::acc::DeviceType::None);
2791SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2793 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2794 getHasWaitDevnum(), deviceType);
2798 return getWaitDevnum(mlir::acc::DeviceType::None);
2801mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2803 getWaitOperandsSegments(), getHasWaitDevnum(),
2807LogicalResult acc::SerialOp::verify() {
2809 mlir::acc::PrivateRecipeOp>(
2810 *
this, getPrivateOperands(),
"private")))
2813 mlir::acc::FirstprivateRecipeOp>(
2814 *
this, getFirstprivateOperands(),
"firstprivate")))
2817 mlir::acc::ReductionRecipeOp>(
2818 *
this, getReductionOperands(),
"reduction")))
2822 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2823 getWaitOperandsDeviceTypeAttr(),
"wait")))
2827 getAsyncOperandsDeviceTypeAttr(),
2837void acc::SerialOp::addAsyncOnly(
2839 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2840 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2843void acc::SerialOp::addAsyncOperand(
2846 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2847 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2848 getAsyncOperandsMutable()));
2851void acc::SerialOp::addWaitOnly(
2853 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2854 effectiveDeviceTypes));
2856void acc::SerialOp::addWaitOperands(
2861 if (getWaitOperandsSegments())
2862 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2864 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2865 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2866 getWaitOperandsMutable(), segments));
2867 setWaitOperandsSegments(segments);
2870 if (getHasWaitDevnumAttr())
2871 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2874 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2876 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2879void acc::SerialOp::addPrivatization(
MLIRContext *context,
2880 mlir::acc::PrivateOp op,
2881 mlir::acc::PrivateRecipeOp recipe) {
2882 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2883 getPrivateOperandsMutable().append(op.getResult());
2886void acc::SerialOp::addFirstPrivatization(
2887 MLIRContext *context, mlir::acc::FirstprivateOp op,
2888 mlir::acc::FirstprivateRecipeOp recipe) {
2889 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2890 getFirstprivateOperandsMutable().append(op.getResult());
2893void acc::SerialOp::addReduction(
MLIRContext *context,
2894 mlir::acc::ReductionOp op,
2895 mlir::acc::ReductionRecipeOp recipe) {
2896 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2897 getReductionOperandsMutable().append(op.getResult());
2904unsigned KernelsOp::getNumDataOperands() {
2905 return getDataClauseOperands().size();
2908Value KernelsOp::getDataOperand(
unsigned i) {
2910 numOptional += getWaitOperands().size();
2911 numOptional += getNumGangs().size();
2912 numOptional += getNumWorkers().size();
2913 numOptional += getVectorLength().size();
2914 numOptional += getIfCond() ? 1 : 0;
2915 numOptional += getSelfCond() ? 1 : 0;
2916 return getOperand(numOptional + i);
2919bool acc::KernelsOp::hasAsyncOnly() {
2920 return hasAsyncOnly(mlir::acc::DeviceType::None);
2923bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2928 return getAsyncValue(mlir::acc::DeviceType::None);
2931mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2937 return getNumWorkersValue(mlir::acc::DeviceType::None);
2941acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2946mlir::Value acc::KernelsOp::getVectorLengthValue() {
2947 return getVectorLengthValue(mlir::acc::DeviceType::None);
2951acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2953 getVectorLength(), deviceType);
2957 return getNumGangsValues(mlir::acc::DeviceType::None);
2961KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2963 getNumGangsSegments(), deviceType);
2966bool acc::KernelsOp::hasWaitOnly() {
2967 return hasWaitOnly(mlir::acc::DeviceType::None);
2970bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2975 return getWaitValues(mlir::acc::DeviceType::None);
2979KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2981 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2982 getHasWaitDevnum(), deviceType);
2986 return getWaitDevnum(mlir::acc::DeviceType::None);
2989mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2991 getWaitOperandsSegments(), getHasWaitDevnum(),
2995LogicalResult acc::KernelsOp::verify() {
2997 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2998 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
3002 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3003 getWaitOperandsDeviceTypeAttr(),
"wait")))
3007 getNumWorkersDeviceTypeAttr(),
3012 getVectorLengthDeviceTypeAttr(),
3017 getAsyncOperandsDeviceTypeAttr(),
3027void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3028 mlir::acc::PrivateOp op,
3029 mlir::acc::PrivateRecipeOp recipe) {
3030 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3031 getPrivateOperandsMutable().append(op.getResult());
3034void acc::KernelsOp::addFirstPrivatization(
3035 MLIRContext *context, mlir::acc::FirstprivateOp op,
3036 mlir::acc::FirstprivateRecipeOp recipe) {
3037 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3038 getFirstprivateOperandsMutable().append(op.getResult());
3041void acc::KernelsOp::addReduction(
MLIRContext *context,
3042 mlir::acc::ReductionOp op,
3043 mlir::acc::ReductionRecipeOp recipe) {
3044 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3045 getReductionOperandsMutable().append(op.getResult());
3048void acc::KernelsOp::addNumWorkersOperand(
3051 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3052 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3053 getNumWorkersMutable()));
3056void acc::KernelsOp::addVectorLengthOperand(
3059 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3060 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3061 getVectorLengthMutable()));
3063void acc::KernelsOp::addAsyncOnly(
3065 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3066 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3069void acc::KernelsOp::addAsyncOperand(
3072 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3073 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3074 getAsyncOperandsMutable()));
3077void acc::KernelsOp::addNumGangsOperands(
3081 if (getNumGangsSegmentsAttr())
3082 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3084 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3085 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3086 getNumGangsMutable(), segments));
3088 setNumGangsSegments(segments);
3091void acc::KernelsOp::addWaitOnly(
3093 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3094 effectiveDeviceTypes));
3096void acc::KernelsOp::addWaitOperands(
3101 if (getWaitOperandsSegments())
3102 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3104 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3105 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3106 getWaitOperandsMutable(), segments));
3107 setWaitOperandsSegments(segments);
3110 if (getHasWaitDevnumAttr())
3111 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3114 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3116 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3123LogicalResult acc::HostDataOp::verify() {
3124 if (getDataClauseOperands().empty())
3125 return emitError(
"at least one operand must appear on the host_data "
3129 for (
mlir::Value operand : getDataClauseOperands()) {
3131 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3133 return emitError(
"expect data entry operation as defining op");
3136 if (!seenVars.insert(useDeviceOp.getVar()).second)
3137 return emitError(
"duplicate use_device variable");
3144 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3156 bool &needCommaBetweenValues,
bool &newValue) {
3163 attributes.push_back(gangArgType);
3164 needCommaBetweenValues =
true;
3175 mlir::ArrayAttr &gangOnlyDeviceType) {
3180 bool needCommaBetweenValues =
false;
3181 bool needCommaBeforeOperands =
false;
3185 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3186 parser.
getContext(), mlir::acc::DeviceType::None));
3187 gangOnlyDeviceType =
3188 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3196 if (parser.parseAttribute(
3197 gangOnlyDeviceTypeAttributes.emplace_back()))
3204 needCommaBeforeOperands =
true;
3207 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3208 mlir::acc::GangArgType::Num);
3209 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3210 mlir::acc::GangArgType::Dim);
3211 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3212 parser.
getContext(), mlir::acc::GangArgType::Static);
3215 if (needCommaBeforeOperands) {
3216 needCommaBeforeOperands =
false;
3223 int32_t crtOperandsSize = gangOperands.size();
3225 bool newValue =
false;
3226 bool needValue =
false;
3227 if (needCommaBetweenValues) {
3235 gangOperands, gangOperandsType,
3236 gangArgTypeAttributes, argNum,
3237 needCommaBetweenValues, newValue)))
3240 gangOperands, gangOperandsType,
3241 gangArgTypeAttributes, argDim,
3242 needCommaBetweenValues, newValue)))
3244 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3245 gangOperands, gangOperandsType,
3246 gangArgTypeAttributes, argStatic,
3247 needCommaBetweenValues, newValue)))
3250 if (!newValue && needValue) {
3252 "new value expected after comma");
3260 if (gangOperands.empty())
3263 "expect at least one of num, dim or static values");
3269 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3273 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3274 parser.
getContext(), mlir::acc::DeviceType::None));
3277 seg.push_back(gangOperands.size() - crtOperandsSize);
3285 gangArgTypeAttributes.end());
3286 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3287 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3290 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3291 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3299 std::optional<mlir::ArrayAttr> gangArgTypes,
3300 std::optional<mlir::ArrayAttr> deviceTypes,
3301 std::optional<mlir::DenseI32ArrayAttr> segments,
3302 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3304 if (operands.begin() == operands.end() &&
3319 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3321 llvm::interleaveComma(
3322 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3323 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3324 (*gangArgTypes)[opIdx]);
3325 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3326 p << LoopOp::getGangNumKeyword();
3327 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3328 p << LoopOp::getGangDimKeyword();
3329 else if (gangArgTypeAttr.getValue() ==
3330 mlir::acc::GangArgType::Static)
3331 p << LoopOp::getGangStaticKeyword();
3332 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3343 std::optional<mlir::ArrayAttr> segments,
3344 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3347 for (
auto attr : *segments) {
3348 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3349 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3357static std::optional<mlir::acc::DeviceType>
3359 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3361 return std::nullopt;
3362 for (
auto attr : deviceTypes) {
3363 auto deviceTypeAttr =
3364 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3365 if (!deviceTypeAttr)
3366 return mlir::acc::DeviceType::None;
3367 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3368 return deviceTypeAttr.getValue();
3370 return std::nullopt;
3373LogicalResult acc::LoopOp::verify() {
3374 if (getUpperbound().size() != getStep().size())
3375 return emitError() <<
"number of upperbounds expected to be the same as "
3378 if (getUpperbound().size() != getLowerbound().size())
3379 return emitError() <<
"number of upperbounds expected to be the same as "
3380 "number of lowerbounds";
3382 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3383 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3384 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3385 <<
" as upperbound size";
3388 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3389 return emitOpError() <<
"collapse device_type attr must be define when"
3390 <<
" collapse attr is present";
3392 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3393 getCollapseAttr().getValue().size() !=
3394 getCollapseDeviceTypeAttr().getValue().size())
3395 return emitOpError() <<
"collapse attribute count must match collapse"
3396 <<
" device_type count";
3397 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3399 << acc::stringifyDeviceType(*duplicateDeviceType)
3400 <<
"` found in collapseDeviceType attribute";
3403 if (!getGangOperands().empty()) {
3404 if (!getGangOperandsArgType())
3405 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3406 <<
" when gang operands are present";
3408 if (getGangOperands().size() !=
3409 getGangOperandsArgTypeAttr().getValue().size())
3410 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3411 <<
" gangOperands count";
3413 if (getGangAttr()) {
3416 << acc::stringifyDeviceType(*duplicateDeviceType)
3417 <<
"` found in gang attribute";
3421 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3422 getGangOperandsDeviceTypeAttr(),
"gang")))
3428 << acc::stringifyDeviceType(*duplicateDeviceType)
3429 <<
"` found in worker attribute";
3430 if (
auto duplicateDeviceType =
3433 << acc::stringifyDeviceType(*duplicateDeviceType)
3434 <<
"` found in workerNumOperandsDeviceType attribute";
3436 getWorkerNumOperandsDeviceTypeAttr(),
3443 << acc::stringifyDeviceType(*duplicateDeviceType)
3444 <<
"` found in vector attribute";
3445 if (
auto duplicateDeviceType =
3448 << acc::stringifyDeviceType(*duplicateDeviceType)
3449 <<
"` found in vectorOperandsDeviceType attribute";
3451 getVectorOperandsDeviceTypeAttr(),
3456 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3457 getTileOperandsDeviceTypeAttr(),
"tile")))
3461 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3465 return emitError() <<
"only one of auto, independent, seq can be present "
3471 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3472 return attr.getValue() == mlir::acc::DeviceType::None;
3474 bool hasDefaultSeq =
3476 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3479 bool hasDefaultIndependent =
3480 getIndependentAttr()
3482 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3485 bool hasDefaultAuto =
3487 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3490 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3492 <<
"at least one of auto, independent, seq must be present";
3497 for (
auto attr : getSeqAttr()) {
3498 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3499 if (hasVector(deviceTypeAttr.getValue()) ||
3500 getVectorValue(deviceTypeAttr.getValue()) ||
3501 hasWorker(deviceTypeAttr.getValue()) ||
3502 getWorkerValue(deviceTypeAttr.getValue()) ||
3503 hasGang(deviceTypeAttr.getValue()) ||
3504 getGangValue(mlir::acc::GangArgType::Num,
3505 deviceTypeAttr.getValue()) ||
3506 getGangValue(mlir::acc::GangArgType::Dim,
3507 deviceTypeAttr.getValue()) ||
3508 getGangValue(mlir::acc::GangArgType::Static,
3509 deviceTypeAttr.getValue()))
3510 return emitError() <<
"gang, worker or vector cannot appear with seq";
3515 mlir::acc::PrivateRecipeOp>(
3516 *
this, getPrivateOperands(),
"private")))
3520 mlir::acc::FirstprivateRecipeOp>(
3521 *
this, getFirstprivateOperands(),
"firstprivate")))
3525 mlir::acc::ReductionRecipeOp>(
3526 *
this, getReductionOperands(),
"reduction")))
3529 if (getCombined().has_value() &&
3530 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3531 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3532 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3533 return emitError(
"unexpected combined constructs attribute");
3537 if (getRegion().empty())
3538 return emitError(
"expected non-empty body.");
3540 if (getUnstructured()) {
3541 if (!isContainerLike())
3543 "unstructured acc.loop must not have induction variables");
3544 }
else if (isContainerLike()) {
3548 uint64_t collapseCount = getCollapseValue().value_or(1);
3549 if (getCollapseAttr()) {
3550 for (
auto collapseEntry : getCollapseAttr()) {
3551 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3552 if (intAttr.getValue().getZExtValue() > collapseCount)
3553 collapseCount = intAttr.getValue().getZExtValue();
3561 bool foundSibling =
false;
3563 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3565 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3567 foundSibling =
true;
3572 expectedParent = op;
3575 if (collapseCount == 0)
3581 return emitError(
"found sibling loops inside container-like acc.loop");
3582 if (collapseCount != 0)
3583 return emitError(
"failed to find enough loop-like operations inside "
3584 "container-like acc.loop");
3590unsigned LoopOp::getNumDataOperands() {
3591 return getReductionOperands().size() + getPrivateOperands().size() +
3592 getFirstprivateOperands().size();
3595Value LoopOp::getDataOperand(
unsigned i) {
3596 unsigned numOptional =
3597 getLowerbound().size() + getUpperbound().size() + getStep().size();
3598 numOptional += getGangOperands().size();
3599 numOptional += getVectorOperands().size();
3600 numOptional += getWorkerNumOperands().size();
3601 numOptional += getTileOperands().size();
3602 numOptional += getCacheOperands().size();
3603 return getOperand(numOptional + i);
3606bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3608bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3612bool LoopOp::hasIndependent() {
3613 return hasIndependent(mlir::acc::DeviceType::None);
3616bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3620bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3622bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3627 return getVectorValue(mlir::acc::DeviceType::None);
3630mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3632 getVectorOperands(), deviceType);
3635bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3637bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3642 return getWorkerValue(mlir::acc::DeviceType::None);
3645mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3647 getWorkerNumOperands(), deviceType);
3650bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3652bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3657 return getTileValues(mlir::acc::DeviceType::None);
3661LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3663 getTileOperandsSegments(), deviceType);
3666std::optional<int64_t> LoopOp::getCollapseValue() {
3667 return getCollapseValue(mlir::acc::DeviceType::None);
3670std::optional<int64_t>
3671LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3672 if (!getCollapseAttr())
3673 return std::nullopt;
3674 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3676 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3677 return intAttr.getValue().getZExtValue();
3679 return std::nullopt;
3682mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3683 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3686mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3687 mlir::acc::DeviceType deviceType) {
3688 if (getGangOperands().empty())
3690 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3691 int32_t nbOperandsBefore = 0;
3692 for (
unsigned i = 0; i < *pos; ++i)
3693 nbOperandsBefore += (*getGangOperandsSegments())[i];
3696 .drop_front(nbOperandsBefore)
3697 .take_front((*getGangOperandsSegments())[*pos]);
3699 int32_t argTypeIdx = nbOperandsBefore;
3700 for (
auto value : values) {
3701 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3702 (*getGangOperandsArgType())[argTypeIdx]);
3703 if (gangArgTypeAttr.getValue() == gangArgType)
3711bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3713bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3718 return {&getRegion()};
3762 if (!regionArgs.empty()) {
3763 p << acc::LoopOp::getControlKeyword() <<
"(";
3764 llvm::interleaveComma(regionArgs, p,
3766 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3767 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3768 <<
" : " << stepType <<
") ";
3775 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3776 effectiveDeviceTypes));
3779void acc::LoopOp::addIndependent(
3781 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3782 context, getIndependentAttr(), effectiveDeviceTypes));
3787 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3788 effectiveDeviceTypes));
3791void acc::LoopOp::setCollapseForDeviceTypes(
3793 llvm::APInt value) {
3797 assert((getCollapseAttr() ==
nullptr) ==
3798 (getCollapseDeviceTypeAttr() ==
nullptr));
3799 assert(value.getBitWidth() == 64);
3801 if (getCollapseAttr()) {
3802 for (
const auto &existing :
3803 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3804 newValues.push_back(std::get<0>(existing));
3805 newDeviceTypes.push_back(std::get<1>(existing));
3809 if (effectiveDeviceTypes.empty()) {
3812 newValues.push_back(
3813 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3814 newDeviceTypes.push_back(
3815 acc::DeviceTypeAttr::get(context, DeviceType::None));
3817 for (DeviceType dt : effectiveDeviceTypes) {
3818 newValues.push_back(
3819 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3820 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3824 setCollapseAttr(ArrayAttr::get(context, newValues));
3825 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3828void acc::LoopOp::setTileForDeviceTypes(
3832 if (getTileOperandsSegments())
3833 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3835 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3836 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3837 getTileOperandsMutable(), segments));
3839 setTileOperandsSegments(segments);
3842void acc::LoopOp::addVectorOperand(
3845 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3846 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3847 newValue, getVectorOperandsMutable()));
3850void acc::LoopOp::addEmptyVector(
3852 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3853 effectiveDeviceTypes));
3856void acc::LoopOp::addWorkerNumOperand(
3859 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3860 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3861 newValue, getWorkerNumOperandsMutable()));
3864void acc::LoopOp::addEmptyWorker(
3866 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3867 effectiveDeviceTypes));
3870void acc::LoopOp::addEmptyGang(
3872 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3873 effectiveDeviceTypes));
3876bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3877 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3878 return attr.getValue() == dt;
3880 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3881 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3884 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3886 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3888 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3894bool acc::LoopOp::hasDefaultGangWorkerVector() {
3895 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3896 hasGang() || getGangValue(GangArgType::Num) ||
3897 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3901acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3902 if (hasSeq(deviceType))
3903 return LoopParMode::loop_seq;
3904 if (hasAuto(deviceType))
3905 return LoopParMode::loop_auto;
3906 if (hasIndependent(deviceType))
3907 return LoopParMode::loop_independent;
3909 return LoopParMode::loop_seq;
3911 return LoopParMode::loop_auto;
3912 assert(hasIndependent() &&
3913 "loop must have default auto, seq, or independent");
3914 return LoopParMode::loop_independent;
3917void acc::LoopOp::addGangOperands(
3922 getGangOperandsSegments())
3923 llvm::copy(*existingSegments, std::back_inserter(segments));
3925 unsigned beforeCount = segments.size();
3927 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3928 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3929 getGangOperandsMutable(), segments));
3931 setGangOperandsSegments(segments);
3938 unsigned numAdded = segments.size() - beforeCount;
3942 if (getGangOperandsArgTypeAttr())
3943 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3945 for (
auto i : llvm::index_range(0u, numAdded)) {
3946 llvm::transform(argTypes, std::back_inserter(gangTypes),
3947 [=](mlir::acc::GangArgType gangTy) {
3948 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3953 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3957void acc::LoopOp::addPrivatization(
MLIRContext *context,
3958 mlir::acc::PrivateOp op,
3959 mlir::acc::PrivateRecipeOp recipe) {
3960 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3961 getPrivateOperandsMutable().append(op.getResult());
3964void acc::LoopOp::addFirstPrivatization(
3965 MLIRContext *context, mlir::acc::FirstprivateOp op,
3966 mlir::acc::FirstprivateRecipeOp recipe) {
3967 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3968 getFirstprivateOperandsMutable().append(op.getResult());
3971void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3972 mlir::acc::ReductionRecipeOp recipe) {
3973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3974 getReductionOperandsMutable().append(op.getResult());
3981LogicalResult acc::DataOp::verify() {
3986 return emitError(
"at least one operand or the default attribute "
3987 "must appear on the data operation");
3989 for (
mlir::Value operand : getDataClauseOperands())
3990 if (isa<BlockArgument>(operand) ||
3991 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3992 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3993 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3994 operand.getDefiningOp()))
3995 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
/// Returns how many values getDataClauseOperands() currently holds for this
/// DataOp.
4004unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
4006Value DataOp::getDataOperand(
unsigned i) {
4007 unsigned numOptional = getIfCond() ? 1 : 0;
4009 numOptional += getWaitOperands().size();
4010 return getOperand(numOptional + i);
4013bool acc::DataOp::hasAsyncOnly() {
4014 return hasAsyncOnly(mlir::acc::DeviceType::None);
4017bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4022 return getAsyncValue(mlir::acc::DeviceType::None);
4025mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
/// Parameterless convenience overload: answers hasWaitOnly(DeviceType) for
/// mlir::acc::DeviceType::None.
4030bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
4032bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4037 return getWaitValues(mlir::acc::DeviceType::None);
4041DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4043 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4044 getHasWaitDevnum(), deviceType);
4048 return getWaitDevnum(mlir::acc::DeviceType::None);
4051mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4053 getWaitOperandsSegments(), getHasWaitDevnum(),
4057void acc::DataOp::addAsyncOnly(
4059 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4060 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4063void acc::DataOp::addAsyncOperand(
4066 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4067 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4068 getAsyncOperandsMutable()));
4071void acc::DataOp::addWaitOnly(
MLIRContext *context,
4073 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4074 effectiveDeviceTypes));
4077void acc::DataOp::addWaitOperands(
4082 if (getWaitOperandsSegments())
4083 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4085 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4086 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4087 getWaitOperandsMutable(), segments));
4088 setWaitOperandsSegments(segments);
4091 if (getHasWaitDevnumAttr())
4092 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4095 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4097 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4104LogicalResult acc::ExitDataOp::verify() {
4108 if (getDataClauseOperands().empty())
4109 return emitError(
"at least one operand must be present in dataOperands on "
4110 "the exit data operation");
4114 if (getAsyncOperand() && getAsync())
4115 return emitError(
"async attribute cannot appear with asyncOperand");
4119 if (!getWaitOperands().empty() && getWait())
4120 return emitError(
"wait attribute cannot appear with waitOperands");
4122 if (getWaitDevnum() && getWaitOperands().empty())
4123 return emitError(
"wait_devnum cannot appear without waitOperands");
4128unsigned ExitDataOp::getNumDataOperands() {
4129 return getDataClauseOperands().size();
4132Value ExitDataOp::getDataOperand(
unsigned i) {
4133 unsigned numOptional = getIfCond() ? 1 : 0;
4134 numOptional += getAsyncOperand() ? 1 : 0;
4135 numOptional += getWaitDevnum() ? 1 : 0;
4136 return getOperand(getWaitOperands().size() + numOptional + i);
4141 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4144void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4146 assert(effectiveDeviceTypes.empty());
4147 assert(!getAsyncAttr());
4148 assert(!getAsyncOperand());
4150 setAsyncAttr(mlir::UnitAttr::get(context));
4153void ExitDataOp::addAsyncOperand(
4156 assert(effectiveDeviceTypes.empty());
4157 assert(!getAsyncAttr());
4158 assert(!getAsyncOperand());
4160 getAsyncOperandMutable().append(newValue);
4165 assert(effectiveDeviceTypes.empty());
4166 assert(!getWaitAttr());
4167 assert(getWaitOperands().empty());
4168 assert(!getWaitDevnum());
4170 setWaitAttr(mlir::UnitAttr::get(context));
4173void ExitDataOp::addWaitOperands(
4176 assert(effectiveDeviceTypes.empty());
4177 assert(!getWaitAttr());
4178 assert(getWaitOperands().empty());
4179 assert(!getWaitDevnum());
4184 getWaitDevnumMutable().append(newValues.front());
4185 newValues = newValues.drop_front();
4188 getWaitOperandsMutable().append(newValues);
4195LogicalResult acc::EnterDataOp::verify() {
4199 if (getDataClauseOperands().empty())
4200 return emitError(
"at least one operand must be present in dataOperands on "
4201 "the enter data operation");
4205 if (getAsyncOperand() && getAsync())
4206 return emitError(
"async attribute cannot appear with asyncOperand");
4210 if (!getWaitOperands().empty() && getWait())
4211 return emitError(
"wait attribute cannot appear with waitOperands");
4213 if (getWaitDevnum() && getWaitOperands().empty())
4214 return emitError(
"wait_devnum cannot appear without waitOperands");
4216 for (
mlir::Value operand : getDataClauseOperands())
4217 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4218 operand.getDefiningOp()))
4219 return emitError(
"expect data entry operation as defining op");
4224unsigned EnterDataOp::getNumDataOperands() {
4225 return getDataClauseOperands().size();
4228Value EnterDataOp::getDataOperand(
unsigned i) {
4229 unsigned numOptional = getIfCond() ? 1 : 0;
4230 numOptional += getAsyncOperand() ? 1 : 0;
4231 numOptional += getWaitDevnum() ? 1 : 0;
4232 return getOperand(getWaitOperands().size() + numOptional + i);
4237 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4240void EnterDataOp::addAsyncOnly(
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getAsyncAttr());
4244 assert(!getAsyncOperand());
4246 setAsyncAttr(mlir::UnitAttr::get(context));
4249void EnterDataOp::addAsyncOperand(
4252 assert(effectiveDeviceTypes.empty());
4253 assert(!getAsyncAttr());
4254 assert(!getAsyncOperand());
4256 getAsyncOperandMutable().append(newValue);
4259void EnterDataOp::addWaitOnly(
MLIRContext *context,
4261 assert(effectiveDeviceTypes.empty());
4262 assert(!getWaitAttr());
4263 assert(getWaitOperands().empty());
4264 assert(!getWaitDevnum());
4266 setWaitAttr(mlir::UnitAttr::get(context));
4269void EnterDataOp::addWaitOperands(
4272 assert(effectiveDeviceTypes.empty());
4273 assert(!getWaitAttr());
4274 assert(getWaitOperands().empty());
4275 assert(!getWaitDevnum());
4280 getWaitDevnumMutable().append(newValues.front());
4281 newValues = newValues.drop_front();
4284 getWaitOperandsMutable().append(newValues);
/// Verifier hook for AtomicReadOp. The whole check is delegated to
/// verifyCommon(), whose definition is not part of this section.
4291LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
/// Verifier hook for AtomicWriteOp. The whole check is delegated to
/// verifyCommon(), whose definition is not part of this section.
4297LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4303LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4310 if (
Value writeVal = op.getWriteOpVal()) {
/// Verifier hook for AtomicUpdateOp. The whole check is delegated to
/// verifyCommon(), whose definition is not part of this section.
4319LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
/// Region verifier hook for AtomicUpdateOp. Returns whatever
/// verifyRegionsCommon() reports; that helper is defined outside this section.
4321LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4327AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4328 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4330 return dyn_cast<AtomicReadOp>(getSecondOp());
4333AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4334 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4336 return dyn_cast<AtomicWriteOp>(getSecondOp());
4339AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4340 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4342 return dyn_cast<AtomicUpdateOp>(getSecondOp());
/// Region verifier hook for AtomicCaptureOp. Returns whatever
/// verifyRegionsCommon() reports; that helper is defined outside this section.
4345LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4351template <
typename Op>
4354 bool requireAtLeastOneOperand =
true) {
4355 if (operands.empty() && requireAtLeastOneOperand)
4358 "at least one operand must appear on the declare operation");
4361 if (isa<BlockArgument>(operand) ||
4362 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4363 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4364 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4365 operand.getDefiningOp()))
4367 "expect valid declare data entry operation or acc.getdeviceptr "
4371 assert(var &&
"declare operands can only be data entry operations which "
4374 std::optional<mlir::acc::DataClause> dataClauseOptional{
4376 assert(dataClauseOptional.has_value() &&
4377 "declare operands can only be data entry operations which must have "
4379 (
void)dataClauseOptional;
4385LogicalResult acc::DeclareEnterOp::verify() {
4393LogicalResult acc::DeclareExitOp::verify() {
4404LogicalResult acc::DeclareOp::verify() {
4413 acc::DeviceType dtype) {
4414 unsigned parallelism = 0;
4415 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4416 parallelism += op.hasWorker(dtype) ? 1 : 0;
4417 parallelism += op.hasVector(dtype) ? 1 : 0;
4418 parallelism += op.hasSeq(dtype) ? 1 : 0;
4422LogicalResult acc::RoutineOp::verify() {
4423 unsigned baseParallelism =
4426 if (baseParallelism > 1)
4427 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4428 "be present at the same time";
4430 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4432 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4433 if (dtype == acc::DeviceType::None)
4437 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4438 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4439 "be present at the same time for device_type `"
4440 << acc::stringifyDeviceType(dtype) <<
"`";
4447 mlir::ArrayAttr &bindIdName,
4448 mlir::ArrayAttr &bindStrName,
4449 mlir::ArrayAttr &deviceIdTypes,
4450 mlir::ArrayAttr &deviceStrTypes) {
4457 mlir::Attribute newAttr;
4458 bool isSymbolRefAttr;
4459 auto parseResult = parser.parseAttribute(newAttr);
4460 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4461 bindIdNameAttrs.push_back(symbolRefAttr);
4462 isSymbolRefAttr = true;
4463 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4464 bindStrNameAttrs.push_back(stringAttr);
4465 isSymbolRefAttr =
false;
4470 if (isSymbolRefAttr) {
4471 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4472 parser.getContext(), mlir::acc::DeviceType::None));
4474 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4475 parser.getContext(), mlir::acc::DeviceType::None));
4478 if (isSymbolRefAttr) {
4479 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4480 parser.parseRSquare())
4483 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4484 parser.parseRSquare())
4492 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4493 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4494 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4495 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4501 std::optional<mlir::ArrayAttr> bindIdName,
4502 std::optional<mlir::ArrayAttr> bindStrName,
4503 std::optional<mlir::ArrayAttr> deviceIdTypes,
4504 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4511 allBindNames.append(bindIdName->begin(), bindIdName->end());
4512 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4517 allBindNames.append(bindStrName->begin(), bindStrName->end());
4518 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4522 if (!allBindNames.empty())
4523 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4524 [&](
const auto &pair) {
4525 p << std::get<0>(pair);
4531 mlir::ArrayAttr &gang,
4532 mlir::ArrayAttr &gangDim,
4533 mlir::ArrayAttr &gangDimDeviceTypes) {
4536 gangDimDeviceTypeAttrs;
4537 bool needCommaBeforeOperands =
false;
4541 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4542 parser.
getContext(), mlir::acc::DeviceType::None));
4543 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4550 if (parser.parseAttribute(gangAttrs.emplace_back()))
4557 needCommaBeforeOperands =
true;
4560 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4564 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4565 parser.parseColon() ||
4566 parser.parseAttribute(gangDimAttrs.emplace_back()))
4568 if (succeeded(parser.parseOptionalLSquare())) {
4569 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4570 parser.parseRSquare())
4573 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4574 parser.getContext(), mlir::acc::DeviceType::None));
4580 if (
failed(parser.parseRParen()))
4583 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4584 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4585 gangDimDeviceTypes =
4586 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4592 std::optional<mlir::ArrayAttr> gang,
4593 std::optional<mlir::ArrayAttr> gangDim,
4594 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4597 gang->size() == 1) {
4598 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4599 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4611 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4612 [&](
const auto &pair) {
4613 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4614 p << std::get<0>(pair);
4622 mlir::ArrayAttr &deviceTypes) {
4626 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4627 parser.
getContext(), mlir::acc::DeviceType::None));
4628 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4635 if (parser.parseAttribute(attributes.emplace_back()))
4643 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4649 std::optional<mlir::ArrayAttr> deviceTypes) {
4652 auto deviceTypeAttr =
4653 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4654 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4663 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
/// Parameterless convenience overload: answers hasWorker(DeviceType) for
/// mlir::acc::DeviceType::None.
4669bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4671bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
/// Parameterless convenience overload: answers hasVector(DeviceType) for
/// mlir::acc::DeviceType::None.
4675bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4677bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
/// Parameterless convenience overload: answers hasSeq(DeviceType) for
/// mlir::acc::DeviceType::None.
4681bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4683bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4687std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4688RoutineOp::getBindNameValue() {
4689 return getBindNameValue(mlir::acc::DeviceType::None);
4692std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4693RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4695 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4696 auto attr = (*getBindIdName())[*pos];
4697 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4698 assert(symbolRefAttr &&
"expected SymbolRef");
4699 return symbolRefAttr;
4704 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4705 auto attr = (*getBindStrName())[*pos];
4706 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4707 assert(stringAttr &&
"expected String");
4712 return std::nullopt;
/// Parameterless convenience overload: answers hasGang(DeviceType) for
/// mlir::acc::DeviceType::None.
4715bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4717bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4721std::optional<int64_t> RoutineOp::getGangDimValue() {
4722 return getGangDimValue(mlir::acc::DeviceType::None);
4725std::optional<int64_t>
4726RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4728 return std::nullopt;
4729 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4730 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4731 return intAttr.getInt();
4733 return std::nullopt;
4738 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4739 effectiveDeviceTypes));
4744 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4745 effectiveDeviceTypes));
4750 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4751 effectiveDeviceTypes));
4756 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4757 effectiveDeviceTypes));
4766 if (getGangDimAttr())
4767 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4768 if (getGangDimDeviceTypeAttr())
4769 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4771 assert(dimValues.size() == deviceTypes.size());
4773 if (effectiveDeviceTypes.empty()) {
4774 dimValues.push_back(
4775 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4776 deviceTypes.push_back(
4777 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4779 for (DeviceType dt : effectiveDeviceTypes) {
4780 dimValues.push_back(
4781 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4782 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4785 assert(dimValues.size() == deviceTypes.size());
4787 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4788 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4791void RoutineOp::addBindStrName(
MLIRContext *context,
4793 mlir::StringAttr val) {
4794 unsigned before = getBindStrNameDeviceTypeAttr()
4795 ? getBindStrNameDeviceTypeAttr().size()
4798 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4799 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4800 unsigned after = getBindStrNameDeviceTypeAttr().size();
4803 if (getBindStrNameAttr())
4804 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4805 for (
unsigned i = 0; i < after - before; ++i)
4806 vals.push_back(val);
4808 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4811void RoutineOp::addBindIDName(
MLIRContext *context,
4813 mlir::SymbolRefAttr val) {
4815 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4817 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4818 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4819 unsigned after = getBindIdNameDeviceTypeAttr().size();
4822 if (getBindIdNameAttr())
4823 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4824 for (
unsigned i = 0; i < after - before; ++i)
4825 vals.push_back(val);
4827 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4834LogicalResult acc::InitOp::verify() {
4835 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4836 return emitOpError(
"cannot be nested in a compute operation");
4840void acc::InitOp::addDeviceType(
MLIRContext *context,
4841 mlir::acc::DeviceType deviceType) {
4843 if (getDeviceTypesAttr())
4844 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4846 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4847 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4854LogicalResult acc::ShutdownOp::verify() {
4855 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4856 return emitOpError(
"cannot be nested in a compute operation");
4860void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4861 mlir::acc::DeviceType deviceType) {
4863 if (getDeviceTypesAttr())
4864 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4866 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4867 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4874LogicalResult acc::SetOp::verify() {
4875 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4876 return emitOpError(
"cannot be nested in a compute operation");
4877 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4878 return emitOpError(
"at least one default_async, device_num, or device_type "
4879 "operand must appear");
4887LogicalResult acc::UpdateOp::verify() {
4889 if (getDataClauseOperands().empty())
4890 return emitError(
"at least one value must be present in dataOperands");
4893 getAsyncOperandsDeviceTypeAttr(),
4898 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4899 getWaitOperandsDeviceTypeAttr(),
"wait")))
4905 for (
mlir::Value operand : getDataClauseOperands())
4906 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4907 operand.getDefiningOp()))
4908 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4914unsigned UpdateOp::getNumDataOperands() {
4915 return getDataClauseOperands().size();
4918Value UpdateOp::getDataOperand(
unsigned i) {
4920 numOptional += getIfCond() ? 1 : 0;
4921 return getOperand(getWaitOperands().size() + numOptional + i);
4926 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4929bool UpdateOp::hasAsyncOnly() {
4930 return hasAsyncOnly(mlir::acc::DeviceType::None);
4933bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4938 return getAsyncValue(mlir::acc::DeviceType::None);
4941mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4951bool UpdateOp::hasWaitOnly() {
4952 return hasWaitOnly(mlir::acc::DeviceType::None);
4955bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4960 return getWaitValues(mlir::acc::DeviceType::None);
4964UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4966 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4967 getHasWaitDevnum(), deviceType);
4971 return getWaitDevnum(mlir::acc::DeviceType::None);
4974mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4976 getWaitOperandsSegments(), getHasWaitDevnum(),
4982 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4983 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4986void UpdateOp::addAsyncOperand(
4989 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4990 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4991 getAsyncOperandsMutable()));
4996 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4997 effectiveDeviceTypes));
5000void UpdateOp::addWaitOperands(
5005 if (getWaitOperandsSegments())
5006 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5008 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5009 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5010 getWaitOperandsMutable(), segments));
5011 setWaitOperandsSegments(segments);
5014 if (getHasWaitDevnumAttr())
5015 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5018 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5020 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5027LogicalResult acc::WaitOp::verify() {
5030 if (getAsyncOperand() && getAsync())
5031 return emitError(
"async attribute cannot appear with asyncOperand");
5033 if (getWaitDevnum() && getWaitOperands().empty())
5034 return emitError(
"wait_devnum cannot appear without waitOperands");
5039#define GET_OP_CLASSES
5040#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5042#define GET_ATTRDEF_CLASSES
5043#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5045#define GET_TYPEDEF_CLASSES
5046#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5057 .Case<ACC_DATA_ENTRY_OPS>(
5058 [&](
auto entry) {
return entry.getVarPtr(); })
5059 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5060 [&](
auto exit) {
return exit.getVarPtr(); })
5078 [&](
auto entry) {
return entry.getVarType(); })
5079 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5080 [&](
auto exit) {
return exit.getVarType(); })
5090 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5091 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5101 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5110 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5120 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5122 dataClause.getBounds().begin(), dataClause.getBounds().end());
5134 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5136 dataClause.getAsyncOperands().begin(),
5137 dataClause.getAsyncOperands().end());
5148 return dataClause.getAsyncOperandsDeviceTypeAttr();
5156 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5163 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5170std::optional<mlir::acc::DataClause>
5175 .Case<ACC_DATA_ENTRY_OPS>(
5176 [&](
auto entry) {
return entry.getDataClause(); })
5184 [&](
auto entry) {
return entry.getImplicit(); })
5193 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5195 return dataOperands;
5203 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5205 return dataOperands;
5212 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.