26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/LogicalResult.h"
34#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
38#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
42static bool isScalarLikeType(
Type type) {
50 if (!varName.empty()) {
51 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
57struct MemRefPointerLikeModel
58 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
60 return cast<T>(pointer).getElementType();
63 mlir::acc::VariableTypeCategory
66 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
67 return mappableTy.getTypeCategory(varPtr);
69 auto memrefTy = cast<T>(pointer);
70 if (!memrefTy.hasRank()) {
73 return mlir::acc::VariableTypeCategory::uncategorized;
76 if (memrefTy.getRank() == 0) {
77 if (isScalarLikeType(memrefTy.getElementType())) {
78 return mlir::acc::VariableTypeCategory::scalar;
82 return mlir::acc::VariableTypeCategory::uncategorized;
86 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
87 return mlir::acc::VariableTypeCategory::array;
90 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
91 StringRef varName, Type varType, Value originalVar,
92 bool &needsFree)
const {
93 auto memrefTy = cast<MemRefType>(pointer);
97 if (memrefTy.hasStaticShape()) {
99 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
100 attachVarNameAttr(allocaOp, builder, varName);
101 return allocaOp.getResult();
106 if (originalVar && originalVar.
getType() == memrefTy &&
107 memrefTy.hasRank()) {
108 SmallVector<Value> dynamicSizes;
109 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
110 if (memrefTy.isDynamicDim(i)) {
114 memref::DimOp::create(builder, loc, originalVar, indexValue);
115 dynamicSizes.push_back(dimSize);
122 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
123 attachVarNameAttr(allocOp, builder, varName);
124 return allocOp.getResult();
131 bool genFree(Type pointer, OpBuilder &builder, Location loc,
133 Type varType)
const {
136 Value valueToInspect = allocRes ? allocRes : memrefValue;
139 Value currentValue = valueToInspect;
140 Operation *originalAlloc =
nullptr;
144 while (currentValue) {
147 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
148 originalAlloc = definingOp;
153 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
154 currentValue = castOp.getSource();
159 if (
auto reinterpretCastOp =
160 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
161 currentValue = reinterpretCastOp.getSource();
173 if (isa<memref::AllocaOp>(originalAlloc)) {
177 if (isa<memref::AllocOp>(originalAlloc)) {
179 memref::DeallocOp::create(builder, loc, memrefValue);
188 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
192 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
193 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
199 if (destMemref && srcMemref &&
200 destMemref.getType().getElementType() ==
201 srcMemref.getType().getElementType() &&
202 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
203 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
210 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
212 Type valueType)
const {
217 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
221 auto memrefTy = memrefValue.
getType();
224 if (memrefTy.getRank() != 0)
227 return memref::LoadOp::create(builder, loc, memrefValue);
230 bool genStore(Type pointer, OpBuilder &builder, Location loc,
236 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
240 auto memrefTy = memrefValue.getType();
243 if (memrefTy.getRank() != 0)
246 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
250 Value
genCast(Type, OpBuilder &builder, Location loc, Value value,
251 Type resultType)
const {
252 if (value.
getType() == resultType)
255 if (isa<BaseMemRefType>(value.
getType()) &&
256 isa<BaseMemRefType>(resultType)) {
259 return memref::CastOp::create(builder, loc, resultType, value);
260 if (memref::MemorySpaceCastOp::areCastCompatible(
262 return memref::MemorySpaceCastOp::create(builder, loc, resultType,
269 if (
auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
270 if (!isa<BaseMemRefType>(resPtrLike))
271 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
273 if (
auto valPtrLike = dyn_cast<PointerLikeType>(value.
getType()))
274 if (!isa<BaseMemRefType>(valPtrLike))
275 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
281 bool isDeviceData(Type pointer, Value var)
const {
282 auto memrefTy = cast<T>(pointer);
283 Attribute memSpace = memrefTy.getMemorySpace();
284 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
287 MemRefType getAsMemRefType(Type pointer, ModuleOp module)
const {
289 return dyn_cast<MemRefType>(pointer);
293struct LLVMPointerPointerLikeModel
294 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
295 LLVM::LLVMPointerType> {
298 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
300 Type valueType)
const {
305 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
308 bool genStore(Type pointer, OpBuilder &builder, Location loc,
310 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
314 Value
genCast(Type, OpBuilder &builder, Location loc, Value value,
315 Type resultType)
const {
316 if (value.
getType() == resultType)
319 auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(value.
getType());
320 auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
321 if (srcPtrTy && dstPtrTy) {
322 if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
323 return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
327 if (srcPtrTy && isa<IntegerType>(resultType))
328 return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
331 Value intVal = value;
332 if (isa<IndexType>(value.
getType()))
333 intVal = arith::IndexCastUIOp::create(builder, loc,
335 if (isa<IntegerType>(intVal.
getType()))
336 return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
339 if (
auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
340 if (!isa<LLVM::LLVMPointerType>(resPtrLike))
341 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
343 if (
auto valPtrLike = dyn_cast<PointerLikeType>(value.
getType()))
344 if (!isa<LLVM::LLVMPointerType>(valPtrLike))
345 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
348 return UnrealizedConversionCastOp::create(builder, loc,
354struct PrivateTypePointerLikeModel
355 :
public PointerLikeType::ExternalModel<PrivateTypePointerLikeModel,
358 return cast<PrivateType>(type).getBaseTy();
361 Value
genCast(Type, OpBuilder &builder, Location loc, Value value,
362 Type resultType)
const {
363 if (value.
getType() == resultType)
365 if (!isa<PointerLikeType>(resultType))
367 return UnwrapPrivateOp::create(builder, loc, resultType, value).getResult();
370 MemRefType getAsMemRefType(Type type, ModuleOp module)
const {
371 Type baseTy = cast<PrivateType>(type).getBaseTy();
372 if (
auto memrefTy = dyn_cast<MemRefType>(baseTy))
374 if (
auto ptrLikeTy = dyn_cast<PointerLikeType>(baseTy))
375 return ptrLikeTy.getAsMemRefType(module);
380struct MemrefAddressOfGlobalModel
381 :
public AddressOfGlobalOpInterface::ExternalModel<
382 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
383 SymbolRefAttr getSymbol(Operation *op)
const {
384 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
385 return getGlobalOp.getNameAttr();
389struct MemrefGlobalVariableModel
390 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
392 bool isConstant(Operation *op)
const {
393 auto globalOp = cast<memref::GlobalOp>(op);
394 return globalOp.getConstant();
397 Region *getInitRegion(Operation *op)
const {
402 bool isDeviceData(Operation *op)
const {
403 auto globalOp = cast<memref::GlobalOp>(op);
404 Attribute memSpace = globalOp.getType().getMemorySpace();
405 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
409struct GPULaunchOffloadRegionModel
410 :
public acc::OffloadRegionOpInterface::ExternalModel<
411 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
412 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
413 return cast<gpu::LaunchOp>(op).getBody();
421mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
422 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
425 if (existingDeviceTypes)
426 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
428 if (newDeviceTypes.empty())
429 deviceTypes.push_back(
430 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
432 for (DeviceType dt : newDeviceTypes)
433 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
435 return mlir::ArrayAttr::get(context, deviceTypes);
444mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
445 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
450 if (existingDeviceTypes)
451 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
453 if (newDeviceTypes.empty()) {
454 argCollection.
append(arguments);
455 segments.push_back(arguments.size());
456 deviceTypes.push_back(
457 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
460 for (DeviceType dt : newDeviceTypes) {
461 argCollection.
append(arguments);
462 segments.push_back(arguments.size());
463 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
466 return mlir::ArrayAttr::get(context, deviceTypes);
470mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
471 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
475 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
476 newDeviceTypes, arguments,
477 argCollection, segments);
485void OpenACCDialect::initialize() {
488#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
491#define GET_ATTRDEF_LIST
492#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
495#define GET_TYPEDEF_LIST
496#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
502 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
504 UnrankedMemRefType::attachInterface<
505 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
506 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
508 PrivateType::attachInterface<PrivateTypePointerLikeModel>(*
getContext());
511 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
513 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
514 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
551void ParallelOp::getSuccessorRegions(
581void HostDataOp::getSuccessorRegions(
596 if (getUnstructured()) {
629 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
633 mlir::acc::DeviceType deviceType) {
637 for (
auto attr : *arrayAttr) {
638 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
639 if (deviceTypeAttr.getValue() == deviceType)
647 std::optional<mlir::ArrayAttr> deviceTypes) {
652 llvm::interleaveComma(*deviceTypes, p,
658 mlir::acc::DeviceType deviceType) {
659 unsigned segmentIdx = 0;
660 for (
auto attr : segments) {
661 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
662 if (deviceTypeAttr.getValue() == deviceType)
663 return std::make_optional(segmentIdx);
673 mlir::acc::DeviceType deviceType) {
675 return range.take_front(0);
676 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
677 int32_t nbOperandsBefore = 0;
678 for (
unsigned i = 0; i < *pos; ++i)
679 nbOperandsBefore += (*segments)[i];
680 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
682 return range.take_front(0);
689 std::optional<mlir::ArrayAttr> hasWaitDevnum,
690 mlir::acc::DeviceType deviceType) {
693 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
694 if (hasWaitDevnum && *hasWaitDevnum) {
695 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
696 if (boolAttr && boolAttr.getValue())
709 std::optional<mlir::ArrayAttr> hasWaitDevnum,
710 mlir::acc::DeviceType deviceType) {
715 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
716 if (hasWaitDevnum && *hasWaitDevnum) {
717 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
718 if (boolAttr.getValue())
719 return range.drop_front(1);
725template <
typename Op>
727 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
729 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
734 op.hasAsyncOnly(dtype))
736 "asyncOnly attribute cannot appear with asyncOperand");
741 op.hasWaitOnly(dtype))
742 return op.
emitError(
"wait attribute cannot appear with waitOperands");
747template <
typename Op>
750 return op.
emitError(
"must have var operand");
753 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
754 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
755 return op.
emitError(
"var must be mappable or pointer-like");
758 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
759 op.getVarType() == op.getVar().getType())
760 return op.
emitError(
"varType must capture the element type of var");
765template <
typename Op>
767 if (op.getVar().getType() != op.getAccVar().getType())
768 return op.
emitError(
"input and output types must match");
773template <
typename Op>
775 if (op.getModifiers() != acc::DataClauseModifier::none)
776 return op.
emitError(
"no data clause modifiers are allowed");
780template <
typename Op>
783 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
785 "invalid data clause modifiers: " +
786 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
791template <
typename OpT,
typename RecipeOpT>
792static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
797 !std::is_same_v<OpT, acc::ReductionOp>)
800 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
802 return op->emitOpError() <<
"recipe expected for " << operandName;
807 return op->emitOpError()
808 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
809 << operandName <<
" declaration";
830 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
851 if (failed(parser.
parseType(accVarType)))
861 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
873 mlir::TypeAttr &varTypeAttr) {
874 if (failed(parser.
parseType(varPtrType)))
885 varTypeAttr = mlir::TypeAttr::get(varType);
890 if (
auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
891 Type elementType = ptrTy.getElementType();
894 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
896 varTypeAttr = mlir::TypeAttr::get(varPtrType);
904 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
912 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
913 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
917 if (!typeToCheckAgainst)
918 typeToCheckAgainst = varPtrType;
919 if (typeToCheckAgainst != varType) {
927 mlir::SymbolRefAttr &recipeAttr) {
934 mlir::SymbolRefAttr recipeAttr) {
941LogicalResult acc::DataBoundsOp::verify() {
942 auto extent = getExtent();
943 auto upperbound = getUpperbound();
944 if (!extent && !upperbound)
945 return emitError(
"expected extent or upperbound.");
952LogicalResult acc::PrivateOp::verify() {
955 "data clause associated with private operation must match its intent");
969LogicalResult acc::FirstprivateOp::verify() {
971 return emitError(
"data clause associated with firstprivate operation must "
978 *
this,
"firstprivate")))
986LogicalResult acc::ReductionOp::verify() {
988 return emitError(
"data clause associated with reduction operation must "
995 *
this,
"reduction")))
1003LogicalResult acc::DevicePtrOp::verify() {
1005 return emitError(
"data clause associated with deviceptr operation must "
1006 "match its intent");
1019LogicalResult acc::PresentOp::verify() {
1022 "data clause associated with present operation must match its intent");
1035LogicalResult acc::CopyinOp::verify() {
1037 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
1042 "data clause associated with copyin operation must match its intent"
1043 " or specify original clause this operation was decomposed from");
1049 acc::DataClauseModifier::always |
1050 acc::DataClauseModifier::capture)))
1055bool acc::CopyinOp::isCopyinReadonly() {
1056 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
1057 acc::bitEnumContainsAny(getModifiers(),
1058 acc::DataClauseModifier::readonly);
1064LogicalResult acc::CreateOp::verify() {
1071 "data clause associated with create operation must match its intent"
1072 " or specify original clause this operation was decomposed from");
1080 acc::DataClauseModifier::always |
1081 acc::DataClauseModifier::capture)))
1086bool acc::CreateOp::isCreateZero() {
1088 return getDataClause() == acc::DataClause::acc_create_zero ||
1090 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1096LogicalResult acc::NoCreateOp::verify() {
1098 return emitError(
"data clause associated with no_create operation must "
1099 "match its intent");
1112LogicalResult acc::AttachOp::verify() {
1115 "data clause associated with attach operation must match its intent");
1129LogicalResult acc::DeclareDeviceResidentOp::verify() {
1130 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1131 return emitError(
"data clause associated with device_resident operation "
1132 "must match its intent");
1146LogicalResult acc::DeclareLinkOp::verify() {
1149 "data clause associated with link operation must match its intent");
1162LogicalResult acc::CopyoutOp::verify() {
1169 "data clause associated with copyout operation must match its intent"
1170 " or specify original clause this operation was decomposed from");
1172 return emitError(
"must have both host and device pointers");
1178 acc::DataClauseModifier::always |
1179 acc::DataClauseModifier::capture)))
1184bool acc::CopyoutOp::isCopyoutZero() {
1185 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1186 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1192LogicalResult acc::DeleteOp::verify() {
1201 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1204 "data clause associated with delete operation must match its intent"
1205 " or specify original clause this operation was decomposed from");
1207 return emitError(
"must have device pointer");
1211 acc::DataClauseModifier::readonly |
1212 acc::DataClauseModifier::always |
1213 acc::DataClauseModifier::capture)))
1221LogicalResult acc::DetachOp::verify() {
1226 "data clause associated with detach operation must match its intent"
1227 " or specify original clause this operation was decomposed from");
1229 return emitError(
"must have device pointer");
1238LogicalResult acc::UpdateHostOp::verify() {
1243 "data clause associated with host operation must match its intent"
1244 " or specify original clause this operation was decomposed from");
1246 return emitError(
"must have both host and device pointers");
1259LogicalResult acc::UpdateDeviceOp::verify() {
1263 "data clause associated with device operation must match its intent"
1264 " or specify original clause this operation was decomposed from");
1277LogicalResult acc::UseDeviceOp::verify() {
1281 "data clause associated with use_device operation must match its intent"
1282 " or specify original clause this operation was decomposed from");
1295LogicalResult acc::CacheOp::verify() {
1300 "data clause associated with cache operation must match its intent"
1301 " or specify original clause this operation was decomposed from");
1311bool acc::CacheOp::isCacheReadonly() {
1312 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1313 acc::bitEnumContainsAny(getModifiers(),
1314 acc::DataClauseModifier::readonly);
1330template <
typename EffectTy>
1335 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1336 effects.emplace_back(EffectTy::get(), &operand[i]);
1340template <
typename EffectTy>
1345 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1349void acc::PrivateOp::getEffects(
1363void acc::FirstprivateOp::getEffects(
1377void acc::ReductionOp::getEffects(
1391void acc::DevicePtrOp::getEffects(
1400void acc::PresentOp::getEffects(
1411void acc::CopyinOp::getEffects(
1424void acc::CreateOp::getEffects(
1437void acc::NoCreateOp::getEffects(
1448void acc::AttachOp::getEffects(
1461void acc::GetDevicePtrOp::getEffects(
1470void acc::UpdateDeviceOp::getEffects(
1480void acc::UseDeviceOp::getEffects(
1489void acc::DeclareDeviceResidentOp::getEffects(
1500void acc::DeclareLinkOp::getEffects(
1511void acc::CacheOp::getEffects(
1516void acc::CopyoutOp::getEffects(
1529void acc::DeleteOp::getEffects(
1541void acc::DetachOp::getEffects(
1553void acc::UpdateHostOp::getEffects(
1565template <
typename StructureOp>
1567 unsigned nRegions = 1) {
1570 for (
unsigned i = 0; i < nRegions; ++i)
1573 for (
Region *region : regions)
1584template <
typename OpTy>
1586 using OpRewritePattern<OpTy>::OpRewritePattern;
1588 LogicalResult matchAndRewrite(OpTy op,
1589 PatternRewriter &rewriter)
const override {
1591 Value ifCond = op.getIfCond();
1595 IntegerAttr constAttr;
1598 if (constAttr.getInt())
1599 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1611 assert(region.
hasOneBlock() &&
"expected single-block region");
1623template <
typename OpTy>
1624struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1625 using OpRewritePattern<OpTy>::OpRewritePattern;
1627 LogicalResult matchAndRewrite(OpTy op,
1628 PatternRewriter &rewriter)
const override {
1630 Value ifCond = op.getIfCond();
1634 IntegerAttr constAttr;
1637 if (constAttr.getInt())
1638 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1666 for (
Value bound : bounds) {
1667 argTypes.push_back(bound.getType());
1668 argLocs.push_back(loc);
1675 Value privatizedValue;
1681 if (isa<MappableType>(varType)) {
1682 auto mappableTy = cast<MappableType>(varType);
1683 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1684 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1685 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1686 privatizedValue = mappableTy.generatePrivateInit(
1687 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1688 if (!privatizedValue)
1691 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1692 auto pointerLikeTy = cast<PointerLikeType>(varType);
1694 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1695 blockArgVar, needsFree);
1696 if (!privatizedValue)
1701 acc::YieldOp::create(builder, loc, privatizedValue);
1718 for (
Value bound : bounds) {
1719 copyArgTypes.push_back(bound.getType());
1720 copyArgLocs.push_back(loc);
1730 if (isa<MappableType>(varType)) {
1731 auto mappableTy = cast<MappableType>(varType);
1734 if (!mappableTy.generateCopy(
1739 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1740 auto pointerLikeTy = cast<PointerLikeType>(varType);
1741 if (!pointerLikeTy.genCopy(
1748 acc::TerminatorOp::create(builder, loc);
1765 for (
Value bound : bounds) {
1766 destroyArgTypes.push_back(bound.getType());
1767 destroyArgLocs.push_back(loc);
1771 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1775 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1776 if (isa<MappableType>(varType)) {
1777 auto mappableTy = cast<MappableType>(varType);
1778 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1782 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1783 auto pointerLikeTy = cast<PointerLikeType>(varType);
1784 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1788 acc::TerminatorOp::create(builder, loc);
1799 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1801 if (optional && region.
empty())
1805 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1809 return op->
emitOpError() <<
"expects " << regionName
1812 << regionType <<
" type";
1815 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1816 if (yieldOp.getOperands().size() != 1 ||
1817 yieldOp.getOperands().getTypes()[0] != type)
1818 return op->
emitOpError() <<
"expects " << regionName
1820 "yield a value of the "
1821 << regionType <<
" type";
1827LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1829 "privatization",
"init",
getType(),
1833 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1839std::optional<PrivateRecipeOp>
1841 StringRef recipeName,
Value hostVar,
1846 bool isMappable = isa<MappableType>(varType);
1847 bool isPointerLike = isa<PointerLikeType>(varType);
1850 if (!isMappable && !isPointerLike)
1851 return std::nullopt;
1856 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1859 bool needsFree =
false;
1861 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1862 varName, bounds, needsFree, varInfo))) {
1864 return std::nullopt;
1871 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1872 Value allocRes = yieldOp.getOperand(0);
1874 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1875 varType, allocRes, bounds, varInfo))) {
1877 return std::nullopt;
1884std::optional<PrivateRecipeOp>
1886 StringRef recipeName,
1887 FirstprivateRecipeOp firstprivRecipe) {
1890 auto varType = firstprivRecipe.getType();
1891 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1895 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1898 if (!firstprivRecipe.getDestroyRegion().empty()) {
1900 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1910LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1912 "privatization",
"init",
getType(),
1916 if (getCopyRegion().empty())
1917 return emitOpError() <<
"expects non-empty copy region";
1922 return emitOpError() <<
"expects copy region with two arguments of the "
1923 "privatization type";
1925 if (getDestroyRegion().empty())
1929 "privatization",
"destroy",
1936std::optional<FirstprivateRecipeOp>
1938 StringRef recipeName,
Value hostVar,
1943 bool isMappable = isa<MappableType>(varType);
1944 bool isPointerLike = isa<PointerLikeType>(varType);
1947 if (!isMappable && !isPointerLike)
1948 return std::nullopt;
1953 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1956 bool needsFree =
false;
1961 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1962 varName, bounds, needsFree, varInfo))) {
1964 return std::nullopt;
1968 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1969 bounds, varInfo))) {
1971 return std::nullopt;
1978 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1979 Value allocRes = yieldOp.getOperand(0);
1981 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1982 varType, allocRes, bounds, varInfo))) {
1984 return std::nullopt;
1995LogicalResult acc::ReductionRecipeOp::verifyRegions() {
2001 if (getCombinerRegion().empty())
2002 return emitOpError() <<
"expects non-empty combiner region";
2004 Block &reductionBlock = getCombinerRegion().
front();
2008 return emitOpError() <<
"expects combiner region with the first two "
2009 <<
"arguments of the reduction type";
2011 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
2012 if (yieldOp.getOperands().size() != 1 ||
2013 yieldOp.getOperands().getTypes()[0] !=
getType())
2014 return emitOpError() <<
"expects combiner region to yield a value "
2015 "of the reduction type";
2026template <
typename Op>
2030 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2031 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2032 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2033 operand.getDefiningOp()))
2035 "expect data entry/exit operation or acc.getdeviceptr "
2040template <
typename OpT,
typename RecipeOpT>
2043 llvm::StringRef operandName) {
2046 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2048 <<
"expected " << operandName <<
" as defining op";
2049 if (!set.insert(operand).second)
2051 << operandName <<
" operand appears more than once";
2056unsigned ParallelOp::getNumDataOperands() {
2057 return getReductionOperands().size() + getPrivateOperands().size() +
2058 getFirstprivateOperands().size() + getDataClauseOperands().size();
2061Value ParallelOp::getDataOperand(
unsigned i) {
2063 numOptional += getNumGangs().size();
2064 numOptional += getNumWorkers().size();
2065 numOptional += getVectorLength().size();
2066 numOptional += getIfCond() ? 1 : 0;
2067 numOptional += getSelfCond() ? 1 : 0;
2068 return getOperand(getWaitOperands().size() + numOptional + i);
2071template <
typename Op>
2074 llvm::StringRef keyword) {
2075 if (!operands.empty() &&
2076 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
2077 return op.
emitOpError() << keyword <<
" operands count must match "
2078 << keyword <<
" device_type count";
2082template <
typename Op>
2085 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2086 std::size_t numOperandsInSegments = 0;
2087 std::size_t nbOfSegments = 0;
2090 for (
auto segCount : segments.
asArrayRef()) {
2091 if (maxInSegment != 0 && segCount > maxInSegment)
2092 return op.
emitOpError() << keyword <<
" expects a maximum of "
2093 << maxInSegment <<
" values per segment";
2094 numOperandsInSegments += segCount;
2099 if ((numOperandsInSegments != operands.size()) ||
2100 (!deviceTypes && !operands.empty()))
2102 << keyword <<
" operand count does not match count in segments";
2103 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2105 << keyword <<
" segment count does not match device_type count";
2109LogicalResult acc::ParallelOp::verify() {
2111 mlir::acc::PrivateRecipeOp>(
2112 *
this, getPrivateOperands(),
"private")))
2115 mlir::acc::FirstprivateRecipeOp>(
2116 *
this, getFirstprivateOperands(),
"firstprivate")))
2119 mlir::acc::ReductionRecipeOp>(
2120 *
this, getReductionOperands(),
"reduction")))
2124 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2125 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2129 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2130 getWaitOperandsDeviceTypeAttr(),
"wait")))
2134 getNumWorkersDeviceTypeAttr(),
2139 getVectorLengthDeviceTypeAttr(),
2144 getAsyncOperandsDeviceTypeAttr(),
2157 mlir::acc::DeviceType deviceType) {
2160 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2165bool acc::ParallelOp::hasAsyncOnly() {
2166 return hasAsyncOnly(mlir::acc::DeviceType::None);
2169bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2174 return getAsyncValue(mlir::acc::DeviceType::None);
2177mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2182mlir::Value acc::ParallelOp::getNumWorkersValue() {
2183 return getNumWorkersValue(mlir::acc::DeviceType::None);
2187acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2192mlir::Value acc::ParallelOp::getVectorLengthValue() {
2193 return getVectorLengthValue(mlir::acc::DeviceType::None);
2197acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2199 getVectorLength(), deviceType);
2203 return getNumGangsValues(mlir::acc::DeviceType::None);
2207ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2209 getNumGangsSegments(), deviceType);
2213 std::optional<mlir::ArrayAttr> numGangsDeviceType,
2216 std::optional<mlir::ArrayAttr> numWorkersDeviceType,
2218 std::optional<mlir::ArrayAttr> vectorLengthDeviceType,
2220 mlir::acc::DeviceType deviceType) {
2230bool acc::ParallelOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
2232 getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
2233 getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
2234 getVectorLength(), deviceType);
2237bool acc::ParallelOp::hasWaitOnly() {
2238 return hasWaitOnly(mlir::acc::DeviceType::None);
2241bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2246 return getWaitValues(mlir::acc::DeviceType::None);
2250ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2252 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2253 getHasWaitDevnum(), deviceType);
2257 return getWaitDevnum(mlir::acc::DeviceType::None);
2260mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2262 getWaitOperandsSegments(), getHasWaitDevnum(),
2277 odsBuilder, odsState, asyncOperands,
nullptr,
2278 nullptr, waitOperands,
nullptr,
2280 nullptr, numGangs,
nullptr,
2281 nullptr, numWorkers,
2282 nullptr, vectorLength,
2283 nullptr, ifCond, selfCond,
2284 nullptr, reductionOperands, gangPrivateOperands,
2285 gangFirstPrivateOperands, dataClauseOperands,
2289void acc::ParallelOp::addNumWorkersOperand(
2292 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2293 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2294 getNumWorkersMutable()));
2296void acc::ParallelOp::addVectorLengthOperand(
2299 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2300 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2301 getVectorLengthMutable()));
2304void acc::ParallelOp::addAsyncOnly(
2306 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2307 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2310void acc::ParallelOp::addAsyncOperand(
2313 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2314 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2315 getAsyncOperandsMutable()));
2318void acc::ParallelOp::addNumGangsOperands(
2322 if (getNumGangsSegments())
2323 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2325 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2326 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2327 getNumGangsMutable(), segments));
2329 setNumGangsSegments(segments);
2331void acc::ParallelOp::addWaitOnly(
2333 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2334 effectiveDeviceTypes));
2336void acc::ParallelOp::addWaitOperands(
2341 if (getWaitOperandsSegments())
2342 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2344 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2345 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2346 getWaitOperandsMutable(), segments));
2347 setWaitOperandsSegments(segments);
2350 if (getHasWaitDevnumAttr())
2351 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2354 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2356 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2359void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2360 mlir::acc::PrivateOp op,
2361 mlir::acc::PrivateRecipeOp recipe) {
2362 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2363 getPrivateOperandsMutable().append(op.getResult());
2366void acc::ParallelOp::addFirstPrivatization(
2367 MLIRContext *context, mlir::acc::FirstprivateOp op,
2368 mlir::acc::FirstprivateRecipeOp recipe) {
2369 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2370 getFirstprivateOperandsMutable().append(op.getResult());
2373void acc::ParallelOp::addReduction(
MLIRContext *context,
2374 mlir::acc::ReductionOp op,
2375 mlir::acc::ReductionRecipeOp recipe) {
2376 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2377 getReductionOperandsMutable().append(op.getResult());
2392 int32_t crtOperandsSize = operands.size();
2395 if (parser.parseOperand(operands.emplace_back()) ||
2396 parser.parseColonType(types.emplace_back()))
2401 seg.push_back(operands.size() - crtOperandsSize);
2411 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2412 parser.
getContext(), mlir::acc::DeviceType::None));
2418 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2425 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2426 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2427 p <<
" [" << attr <<
"]";
2432 std::optional<mlir::ArrayAttr> deviceTypes,
2433 std::optional<mlir::DenseI32ArrayAttr> segments) {
2435 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2437 llvm::interleaveComma(
2438 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2439 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2459 int32_t crtOperandsSize = operands.size();
2463 if (parser.parseOperand(operands.emplace_back()) ||
2464 parser.parseColonType(types.emplace_back()))
2470 seg.push_back(operands.size() - crtOperandsSize);
2480 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2481 parser.
getContext(), mlir::acc::DeviceType::None));
2487 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2496 std::optional<mlir::DenseI32ArrayAttr> segments) {
2498 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2500 llvm::interleaveComma(
2501 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2502 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2515 mlir::ArrayAttr &keywordOnly) {
2519 bool needCommaBeforeOperands =
false;
2523 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2524 parser.
getContext(), mlir::acc::DeviceType::None));
2525 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2532 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2539 needCommaBeforeOperands =
true;
2542 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2549 int32_t crtOperandsSize = operands.size();
2561 if (parser.parseOperand(operands.emplace_back()) ||
2562 parser.parseColonType(types.emplace_back()))
2568 seg.push_back(operands.size() - crtOperandsSize);
2578 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2579 parser.
getContext(), mlir::acc::DeviceType::None));
2586 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2587 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2589 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2597 if (attrs->size() != 1)
2599 if (
auto deviceTypeAttr =
2600 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2601 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2607 std::optional<mlir::ArrayAttr> deviceTypes,
2608 std::optional<mlir::DenseI32ArrayAttr> segments,
2609 std::optional<mlir::ArrayAttr> hasDevNum,
2610 std::optional<mlir::ArrayAttr> keywordOnly) {
2623 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2625 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2626 if (boolAttr && boolAttr.getValue())
2628 llvm::interleaveComma(
2629 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2630 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2647 if (parser.parseOperand(operands.emplace_back()) ||
2648 parser.parseColonType(types.emplace_back()))
2650 if (succeeded(parser.parseOptionalLSquare())) {
2651 if (parser.parseAttribute(attributes.emplace_back()) ||
2652 parser.parseRSquare())
2655 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2656 parser.getContext(), mlir::acc::DeviceType::None));
2663 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2670 std::optional<mlir::ArrayAttr> deviceTypes) {
2673 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2674 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2683 mlir::ArrayAttr &keywordOnlyDeviceType) {
2686 bool needCommaBeforeOperands =
false;
2690 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2691 parser.
getContext(), mlir::acc::DeviceType::None));
2692 keywordOnlyDeviceType =
2693 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2701 if (parser.parseAttribute(
2702 keywordOnlyDeviceTypeAttributes.emplace_back()))
2709 needCommaBeforeOperands =
true;
2712 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2717 if (parser.parseOperand(operands.emplace_back()) ||
2718 parser.parseColonType(types.emplace_back()))
2720 if (succeeded(parser.parseOptionalLSquare())) {
2721 if (parser.parseAttribute(attributes.emplace_back()) ||
2722 parser.parseRSquare())
2725 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2726 parser.getContext(), mlir::acc::DeviceType::None));
2732 if (
failed(parser.parseRParen()))
2737 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2744 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2746 if (operands.begin() == operands.end() &&
2762 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2763 mlir::Type &operandType, mlir::UnitAttr &attr) {
2766 attr = mlir::UnitAttr::get(parser.
getContext());
2776 if (failed(parser.
parseType(operandType)))
2786 std::optional<mlir::Value> operand,
2788 mlir::UnitAttr attr) {
2805 attr = mlir::UnitAttr::get(parser.
getContext());
2810 if (parser.parseOperand(operands.emplace_back()))
2818 if (parser.parseType(types.emplace_back()))
2833 mlir::UnitAttr attr) {
2838 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2840 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2846 mlir::acc::CombinedConstructsTypeAttr &attr) {
2848 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2849 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2851 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2852 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2854 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2855 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2858 "expected compute construct name");
2866 mlir::acc::CombinedConstructsTypeAttr attr) {
2868 switch (attr.getValue()) {
2869 case mlir::acc::CombinedConstructsType::KernelsLoop:
2872 case mlir::acc::CombinedConstructsType::ParallelLoop:
2875 case mlir::acc::CombinedConstructsType::SerialLoop:
2886unsigned SerialOp::getNumDataOperands() {
2887 return getReductionOperands().size() + getPrivateOperands().size() +
2888 getFirstprivateOperands().size() + getDataClauseOperands().size();
2891Value SerialOp::getDataOperand(
unsigned i) {
2893 numOptional += getIfCond() ? 1 : 0;
2894 numOptional += getSelfCond() ? 1 : 0;
2895 return getOperand(getWaitOperands().size() + numOptional + i);
2898bool acc::SerialOp::hasAsyncOnly() {
2899 return hasAsyncOnly(mlir::acc::DeviceType::None);
2902bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2907 return getAsyncValue(mlir::acc::DeviceType::None);
2910mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2915bool acc::SerialOp::hasWaitOnly() {
2916 return hasWaitOnly(mlir::acc::DeviceType::None);
2919bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2924 return getWaitValues(mlir::acc::DeviceType::None);
2928SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2930 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2931 getHasWaitDevnum(), deviceType);
2935 return getWaitDevnum(mlir::acc::DeviceType::None);
2938mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2940 getWaitOperandsSegments(), getHasWaitDevnum(),
2944LogicalResult acc::SerialOp::verify() {
2946 mlir::acc::PrivateRecipeOp>(
2947 *
this, getPrivateOperands(),
"private")))
2950 mlir::acc::FirstprivateRecipeOp>(
2951 *
this, getFirstprivateOperands(),
"firstprivate")))
2954 mlir::acc::ReductionRecipeOp>(
2955 *
this, getReductionOperands(),
"reduction")))
2959 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2960 getWaitOperandsDeviceTypeAttr(),
"wait")))
2964 getAsyncOperandsDeviceTypeAttr(),
2974void acc::SerialOp::addAsyncOnly(
2976 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2977 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2980void acc::SerialOp::addAsyncOperand(
2983 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2984 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2985 getAsyncOperandsMutable()));
2988void acc::SerialOp::addWaitOnly(
2990 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2991 effectiveDeviceTypes));
2993void acc::SerialOp::addWaitOperands(
2998 if (getWaitOperandsSegments())
2999 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3001 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3002 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3003 getWaitOperandsMutable(), segments));
3004 setWaitOperandsSegments(segments);
3007 if (getHasWaitDevnumAttr())
3008 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3011 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3013 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3016void acc::SerialOp::addPrivatization(
MLIRContext *context,
3017 mlir::acc::PrivateOp op,
3018 mlir::acc::PrivateRecipeOp recipe) {
3019 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3020 getPrivateOperandsMutable().append(op.getResult());
3023void acc::SerialOp::addFirstPrivatization(
3024 MLIRContext *context, mlir::acc::FirstprivateOp op,
3025 mlir::acc::FirstprivateRecipeOp recipe) {
3026 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3027 getFirstprivateOperandsMutable().append(op.getResult());
3030void acc::SerialOp::addReduction(
MLIRContext *context,
3031 mlir::acc::ReductionOp op,
3032 mlir::acc::ReductionRecipeOp recipe) {
3033 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3034 getReductionOperandsMutable().append(op.getResult());
3041unsigned KernelsOp::getNumDataOperands() {
3042 return getDataClauseOperands().size();
3045Value KernelsOp::getDataOperand(
unsigned i) {
3047 numOptional += getWaitOperands().size();
3048 numOptional += getNumGangs().size();
3049 numOptional += getNumWorkers().size();
3050 numOptional += getVectorLength().size();
3051 numOptional += getIfCond() ? 1 : 0;
3052 numOptional += getSelfCond() ? 1 : 0;
3053 return getOperand(numOptional + i);
3056bool acc::KernelsOp::hasAsyncOnly() {
3057 return hasAsyncOnly(mlir::acc::DeviceType::None);
3060bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3065 return getAsyncValue(mlir::acc::DeviceType::None);
3068mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3074 return getNumWorkersValue(mlir::acc::DeviceType::None);
3078acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3083mlir::Value acc::KernelsOp::getVectorLengthValue() {
3084 return getVectorLengthValue(mlir::acc::DeviceType::None);
3088acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3090 getVectorLength(), deviceType);
3094 return getNumGangsValues(mlir::acc::DeviceType::None);
3098KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3100 getNumGangsSegments(), deviceType);
3103bool acc::KernelsOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
3105 getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
3106 getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
3107 getVectorLength(), deviceType);
3110bool acc::KernelsOp::hasWaitOnly() {
3111 return hasWaitOnly(mlir::acc::DeviceType::None);
3114bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3119 return getWaitValues(mlir::acc::DeviceType::None);
3123KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3125 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3126 getHasWaitDevnum(), deviceType);
3130 return getWaitDevnum(mlir::acc::DeviceType::None);
3133mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3135 getWaitOperandsSegments(), getHasWaitDevnum(),
3139LogicalResult acc::KernelsOp::verify() {
3141 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
3142 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
3146 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3147 getWaitOperandsDeviceTypeAttr(),
"wait")))
3151 getNumWorkersDeviceTypeAttr(),
3156 getVectorLengthDeviceTypeAttr(),
3161 getAsyncOperandsDeviceTypeAttr(),
3171void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3172 mlir::acc::PrivateOp op,
3173 mlir::acc::PrivateRecipeOp recipe) {
3174 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3175 getPrivateOperandsMutable().append(op.getResult());
3178void acc::KernelsOp::addFirstPrivatization(
3179 MLIRContext *context, mlir::acc::FirstprivateOp op,
3180 mlir::acc::FirstprivateRecipeOp recipe) {
3181 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3182 getFirstprivateOperandsMutable().append(op.getResult());
3185void acc::KernelsOp::addReduction(
MLIRContext *context,
3186 mlir::acc::ReductionOp op,
3187 mlir::acc::ReductionRecipeOp recipe) {
3188 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3189 getReductionOperandsMutable().append(op.getResult());
3192void acc::KernelsOp::addNumWorkersOperand(
3195 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3196 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3197 getNumWorkersMutable()));
3200void acc::KernelsOp::addVectorLengthOperand(
3203 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3204 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3205 getVectorLengthMutable()));
3207void acc::KernelsOp::addAsyncOnly(
3209 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3210 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3213void acc::KernelsOp::addAsyncOperand(
3216 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3217 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3218 getAsyncOperandsMutable()));
3221void acc::KernelsOp::addNumGangsOperands(
3225 if (getNumGangsSegmentsAttr())
3226 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3228 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3229 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3230 getNumGangsMutable(), segments));
3232 setNumGangsSegments(segments);
3235void acc::KernelsOp::addWaitOnly(
3237 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3238 effectiveDeviceTypes));
3240void acc::KernelsOp::addWaitOperands(
3245 if (getWaitOperandsSegments())
3246 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3248 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3249 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3250 getWaitOperandsMutable(), segments));
3251 setWaitOperandsSegments(segments);
3254 if (getHasWaitDevnumAttr())
3255 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3258 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3260 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3267LogicalResult acc::HostDataOp::verify() {
3268 if (getDataClauseOperands().empty())
3269 return emitError(
"at least one operand must appear on the host_data "
3273 for (
mlir::Value operand : getDataClauseOperands()) {
3275 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3277 return emitError(
"expect data entry operation as defining op");
3280 if (!seenVars.insert(useDeviceOp.getVar()).second)
3281 return emitError(
"duplicate use_device variable");
3288 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3300 bool &needCommaBetweenValues,
bool &newValue) {
3307 attributes.push_back(gangArgType);
3308 needCommaBetweenValues =
true;
3319 mlir::ArrayAttr &gangOnlyDeviceType) {
3324 bool needCommaBetweenValues =
false;
3325 bool needCommaBeforeOperands =
false;
3329 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3330 parser.
getContext(), mlir::acc::DeviceType::None));
3331 gangOnlyDeviceType =
3332 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3340 if (parser.parseAttribute(
3341 gangOnlyDeviceTypeAttributes.emplace_back()))
3348 needCommaBeforeOperands =
true;
3351 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3352 mlir::acc::GangArgType::Num);
3353 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3354 mlir::acc::GangArgType::Dim);
3355 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3356 parser.
getContext(), mlir::acc::GangArgType::Static);
3359 if (needCommaBeforeOperands) {
3360 needCommaBeforeOperands =
false;
3367 int32_t crtOperandsSize = gangOperands.size();
3369 bool newValue =
false;
3370 bool needValue =
false;
3371 if (needCommaBetweenValues) {
3379 gangOperands, gangOperandsType,
3380 gangArgTypeAttributes, argNum,
3381 needCommaBetweenValues, newValue)))
3384 gangOperands, gangOperandsType,
3385 gangArgTypeAttributes, argDim,
3386 needCommaBetweenValues, newValue)))
3388 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3389 gangOperands, gangOperandsType,
3390 gangArgTypeAttributes, argStatic,
3391 needCommaBetweenValues, newValue)))
3394 if (!newValue && needValue) {
3396 "new value expected after comma");
3404 if (gangOperands.empty())
3407 "expect at least one of num, dim or static values");
3413 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3417 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3418 parser.
getContext(), mlir::acc::DeviceType::None));
3421 seg.push_back(gangOperands.size() - crtOperandsSize);
3429 gangArgTypeAttributes.end());
3430 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3431 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3434 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3435 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3443 std::optional<mlir::ArrayAttr> gangArgTypes,
3444 std::optional<mlir::ArrayAttr> deviceTypes,
3445 std::optional<mlir::DenseI32ArrayAttr> segments,
3446 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3448 if (operands.begin() == operands.end() &&
3463 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3465 llvm::interleaveComma(
3466 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3467 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3468 (*gangArgTypes)[opIdx]);
3469 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3470 p << LoopOp::getGangNumKeyword();
3471 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3472 p << LoopOp::getGangDimKeyword();
3473 else if (gangArgTypeAttr.getValue() ==
3474 mlir::acc::GangArgType::Static)
3475 p << LoopOp::getGangStaticKeyword();
3476 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3487 std::optional<mlir::ArrayAttr> segments,
3488 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3491 for (
auto attr : *segments) {
3492 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3493 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3501static std::optional<mlir::acc::DeviceType>
3503 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3505 return std::nullopt;
3506 for (
auto attr : deviceTypes) {
3507 auto deviceTypeAttr =
3508 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3509 if (!deviceTypeAttr)
3510 return mlir::acc::DeviceType::None;
3511 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3512 return deviceTypeAttr.getValue();
3514 return std::nullopt;
3517LogicalResult acc::LoopOp::verify() {
3518 if (getUpperbound().size() != getStep().size())
3519 return emitError() <<
"number of upperbounds expected to be the same as "
3522 if (getUpperbound().size() != getLowerbound().size())
3523 return emitError() <<
"number of upperbounds expected to be the same as "
3524 "number of lowerbounds";
3526 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3527 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3528 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3529 <<
" as upperbound size";
3532 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3533 return emitOpError() <<
"collapse device_type attr must be define when"
3534 <<
" collapse attr is present";
3536 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3537 getCollapseAttr().getValue().size() !=
3538 getCollapseDeviceTypeAttr().getValue().size())
3539 return emitOpError() <<
"collapse attribute count must match collapse"
3540 <<
" device_type count";
3541 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3543 << acc::stringifyDeviceType(*duplicateDeviceType)
3544 <<
"` found in collapseDeviceType attribute";
3547 if (!getGangOperands().empty()) {
3548 if (!getGangOperandsArgType())
3549 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3550 <<
" when gang operands are present";
3552 if (getGangOperands().size() !=
3553 getGangOperandsArgTypeAttr().getValue().size())
3554 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3555 <<
" gangOperands count";
3557 if (getGangAttr()) {
3560 << acc::stringifyDeviceType(*duplicateDeviceType)
3561 <<
"` found in gang attribute";
3565 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3566 getGangOperandsDeviceTypeAttr(),
"gang")))
3572 << acc::stringifyDeviceType(*duplicateDeviceType)
3573 <<
"` found in worker attribute";
3574 if (
auto duplicateDeviceType =
3577 << acc::stringifyDeviceType(*duplicateDeviceType)
3578 <<
"` found in workerNumOperandsDeviceType attribute";
3580 getWorkerNumOperandsDeviceTypeAttr(),
3587 << acc::stringifyDeviceType(*duplicateDeviceType)
3588 <<
"` found in vector attribute";
3589 if (
auto duplicateDeviceType =
3592 << acc::stringifyDeviceType(*duplicateDeviceType)
3593 <<
"` found in vectorOperandsDeviceType attribute";
3595 getVectorOperandsDeviceTypeAttr(),
3600 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3601 getTileOperandsDeviceTypeAttr(),
"tile")))
3605 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3609 return emitError() <<
"only one of auto, independent, seq can be present "
3615 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3616 return attr.getValue() == mlir::acc::DeviceType::None;
3618 bool hasDefaultSeq =
3620 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3623 bool hasDefaultIndependent =
3624 getIndependentAttr()
3626 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3629 bool hasDefaultAuto =
3631 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3634 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3636 <<
"at least one of auto, independent, seq must be present";
3641 for (
auto attr : getSeqAttr()) {
3642 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3643 if (hasVector(deviceTypeAttr.getValue()) ||
3644 getVectorValue(deviceTypeAttr.getValue()) ||
3645 hasWorker(deviceTypeAttr.getValue()) ||
3646 getWorkerValue(deviceTypeAttr.getValue()) ||
3647 hasGang(deviceTypeAttr.getValue()) ||
3648 getGangValue(mlir::acc::GangArgType::Num,
3649 deviceTypeAttr.getValue()) ||
3650 getGangValue(mlir::acc::GangArgType::Dim,
3651 deviceTypeAttr.getValue()) ||
3652 getGangValue(mlir::acc::GangArgType::Static,
3653 deviceTypeAttr.getValue()))
3654 return emitError() <<
"gang, worker or vector cannot appear with seq";
3659 mlir::acc::PrivateRecipeOp>(
3660 *
this, getPrivateOperands(),
"private")))
3664 mlir::acc::FirstprivateRecipeOp>(
3665 *
this, getFirstprivateOperands(),
"firstprivate")))
3669 mlir::acc::ReductionRecipeOp>(
3670 *
this, getReductionOperands(),
"reduction")))
3673 if (getCombined().has_value() &&
3674 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3675 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3676 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3677 return emitError(
"unexpected combined constructs attribute");
3681 if (getRegion().empty())
3682 return emitError(
"expected non-empty body.");
3684 if (getUnstructured()) {
3685 if (!isContainerLike())
3687 "unstructured acc.loop must not have induction variables");
3688 }
else if (isContainerLike()) {
3692 uint64_t collapseCount = getCollapseValue().value_or(1);
3693 if (getCollapseAttr()) {
3694 for (
auto collapseEntry : getCollapseAttr()) {
3695 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3696 if (intAttr.getValue().getZExtValue() > collapseCount)
3697 collapseCount = intAttr.getValue().getZExtValue();
3705 bool foundSibling =
false;
3707 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3709 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3711 foundSibling =
true;
3716 expectedParent = op;
3719 if (collapseCount == 0)
3725 return emitError(
"found sibling loops inside container-like acc.loop");
3726 if (collapseCount != 0)
3727 return emitError(
"failed to find enough loop-like operations inside "
3728 "container-like acc.loop");
3734unsigned LoopOp::getNumDataOperands() {
3735 return getReductionOperands().size() + getPrivateOperands().size() +
3736 getFirstprivateOperands().size();
3739Value LoopOp::getDataOperand(
unsigned i) {
3740 unsigned numOptional =
3741 getLowerbound().size() + getUpperbound().size() + getStep().size();
3742 numOptional += getGangOperands().size();
3743 numOptional += getVectorOperands().size();
3744 numOptional += getWorkerNumOperands().size();
3745 numOptional += getTileOperands().size();
3746 numOptional += getCacheOperands().size();
3747 return getOperand(numOptional + i);
3750bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3752bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3756bool LoopOp::hasIndependent() {
3757 return hasIndependent(mlir::acc::DeviceType::None);
3760bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3764bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3766bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3771 return getVectorValue(mlir::acc::DeviceType::None);
3774mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3776 getVectorOperands(), deviceType);
3779bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3781bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3786 return getWorkerValue(mlir::acc::DeviceType::None);
3789mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3791 getWorkerNumOperands(), deviceType);
3794bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3796bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3801 return getTileValues(mlir::acc::DeviceType::None);
3805LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3807 getTileOperandsSegments(), deviceType);
3810std::optional<int64_t> LoopOp::getCollapseValue() {
3811 return getCollapseValue(mlir::acc::DeviceType::None);
3814std::optional<int64_t>
3815LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3816 if (!getCollapseAttr())
3817 return std::nullopt;
3818 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3820 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3821 return intAttr.getValue().getZExtValue();
3823 return std::nullopt;
3826mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3827 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3830mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3831 mlir::acc::DeviceType deviceType) {
3832 if (getGangOperands().empty())
3834 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3835 int32_t nbOperandsBefore = 0;
3836 for (
unsigned i = 0; i < *pos; ++i)
3837 nbOperandsBefore += (*getGangOperandsSegments())[i];
3840 .drop_front(nbOperandsBefore)
3841 .take_front((*getGangOperandsSegments())[*pos]);
3843 int32_t argTypeIdx = nbOperandsBefore;
3844 for (
auto value : values) {
3845 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3846 (*getGangOperandsArgType())[argTypeIdx]);
3847 if (gangArgTypeAttr.getValue() == gangArgType)
3855bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3857bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3862 return {&getRegion()};
3906 if (!regionArgs.empty()) {
3907 p << acc::LoopOp::getControlKeyword() <<
"(";
3908 llvm::interleaveComma(regionArgs, p,
3910 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3911 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3912 <<
" : " << stepType <<
") ";
3919 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3920 effectiveDeviceTypes));
3923void acc::LoopOp::addIndependent(
3925 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3926 context, getIndependentAttr(), effectiveDeviceTypes));
3931 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3932 effectiveDeviceTypes));
3935void acc::LoopOp::setCollapseForDeviceTypes(
3937 llvm::APInt value) {
3941 assert((getCollapseAttr() ==
nullptr) ==
3942 (getCollapseDeviceTypeAttr() ==
nullptr));
3943 assert(value.getBitWidth() == 64);
3945 if (getCollapseAttr()) {
3946 for (
const auto &existing :
3947 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3948 newValues.push_back(std::get<0>(existing));
3949 newDeviceTypes.push_back(std::get<1>(existing));
3953 if (effectiveDeviceTypes.empty()) {
3956 newValues.push_back(
3957 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3958 newDeviceTypes.push_back(
3959 acc::DeviceTypeAttr::get(context, DeviceType::None));
3961 for (DeviceType dt : effectiveDeviceTypes) {
3962 newValues.push_back(
3963 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3964 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3968 setCollapseAttr(ArrayAttr::get(context, newValues));
3969 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3972void acc::LoopOp::setTileForDeviceTypes(
3976 if (getTileOperandsSegments())
3977 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3979 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3980 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3981 getTileOperandsMutable(), segments));
3983 setTileOperandsSegments(segments);
3986void acc::LoopOp::addVectorOperand(
3989 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3990 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3991 newValue, getVectorOperandsMutable()));
3994void acc::LoopOp::addEmptyVector(
3996 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3997 effectiveDeviceTypes));
4000void acc::LoopOp::addWorkerNumOperand(
4003 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4004 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
4005 newValue, getWorkerNumOperandsMutable()));
4008void acc::LoopOp::addEmptyWorker(
4010 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4011 effectiveDeviceTypes));
4014void acc::LoopOp::addEmptyGang(
4016 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4017 effectiveDeviceTypes));
4020bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
4021 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
4022 return attr.getValue() == dt;
4024 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
4025 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
4028 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
4030 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
4032 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
4038bool acc::LoopOp::hasDefaultGangWorkerVector() {
4039 return hasAnyGangWorkerVector(DeviceType::None);
4042bool acc::LoopOp::hasAnyGangWorkerVector(DeviceType deviceType) {
4043 return hasVector(deviceType) || getVectorValue(deviceType) ||
4044 hasWorker(deviceType) || getWorkerValue(deviceType) ||
4045 hasGang(deviceType) || getGangValue(GangArgType::Num, deviceType) ||
4046 getGangValue(GangArgType::Dim, deviceType) ||
4047 getGangValue(GangArgType::Static, deviceType);
4051acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
4052 if (hasSeq(deviceType))
4053 return LoopParMode::loop_seq;
4054 if (hasAuto(deviceType))
4055 return LoopParMode::loop_auto;
4056 if (hasIndependent(deviceType))
4057 return LoopParMode::loop_independent;
4059 return LoopParMode::loop_seq;
4061 return LoopParMode::loop_auto;
4062 assert(hasIndependent() &&
4063 "loop must have default auto, seq, or independent");
4064 return LoopParMode::loop_independent;
4067void acc::LoopOp::addGangOperands(
4072 getGangOperandsSegments())
4073 llvm::copy(*existingSegments, std::back_inserter(segments));
4075 unsigned beforeCount = segments.size();
4077 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4078 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
4079 getGangOperandsMutable(), segments));
4081 setGangOperandsSegments(segments);
4088 unsigned numAdded = segments.size() - beforeCount;
4092 if (getGangOperandsArgTypeAttr())
4093 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
4095 for (
auto i : llvm::index_range(0u, numAdded)) {
4096 llvm::transform(argTypes, std::back_inserter(gangTypes),
4097 [=](mlir::acc::GangArgType gangTy) {
4098 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
4103 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
4107void acc::LoopOp::addPrivatization(
MLIRContext *context,
4108 mlir::acc::PrivateOp op,
4109 mlir::acc::PrivateRecipeOp recipe) {
4110 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4111 getPrivateOperandsMutable().append(op.getResult());
4114void acc::LoopOp::addFirstPrivatization(
4115 MLIRContext *context, mlir::acc::FirstprivateOp op,
4116 mlir::acc::FirstprivateRecipeOp recipe) {
4117 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4118 getFirstprivateOperandsMutable().append(op.getResult());
4121void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
4122 mlir::acc::ReductionRecipeOp recipe) {
4123 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4124 getReductionOperandsMutable().append(op.getResult());
4131LogicalResult acc::DataOp::verify() {
4136 return emitError(
"at least one operand or the default attribute "
4137 "must appear on the data operation");
4139 for (
mlir::Value operand : getDataClauseOperands())
4140 if (isa<BlockArgument>(operand) ||
4141 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4142 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
4143 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
4144 operand.getDefiningOp()))
4145 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4154unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
4156Value DataOp::getDataOperand(
unsigned i) {
4157 unsigned numOptional = getIfCond() ? 1 : 0;
4159 numOptional += getWaitOperands().size();
4160 return getOperand(numOptional + i);
4163bool acc::DataOp::hasAsyncOnly() {
4164 return hasAsyncOnly(mlir::acc::DeviceType::None);
4167bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4172 return getAsyncValue(mlir::acc::DeviceType::None);
4175mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4180bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
4182bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4187 return getWaitValues(mlir::acc::DeviceType::None);
4191DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4193 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4194 getHasWaitDevnum(), deviceType);
4198 return getWaitDevnum(mlir::acc::DeviceType::None);
4201mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4203 getWaitOperandsSegments(), getHasWaitDevnum(),
4207void acc::DataOp::addAsyncOnly(
4209 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4210 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4213void acc::DataOp::addAsyncOperand(
4216 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4217 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4218 getAsyncOperandsMutable()));
4221void acc::DataOp::addWaitOnly(
MLIRContext *context,
4223 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4224 effectiveDeviceTypes));
4227void acc::DataOp::addWaitOperands(
4232 if (getWaitOperandsSegments())
4233 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4235 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4236 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4237 getWaitOperandsMutable(), segments));
4238 setWaitOperandsSegments(segments);
4241 if (getHasWaitDevnumAttr())
4242 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4245 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4247 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4254LogicalResult acc::ExitDataOp::verify() {
4258 if (getDataClauseOperands().empty())
4259 return emitError(
"at least one operand must be present in dataOperands on "
4260 "the exit data operation");
4264 if (getAsyncOperand() && getAsync())
4265 return emitError(
"async attribute cannot appear with asyncOperand");
4269 if (!getWaitOperands().empty() && getWait())
4270 return emitError(
"wait attribute cannot appear with waitOperands");
4272 if (getWaitDevnum() && getWaitOperands().empty())
4273 return emitError(
"wait_devnum cannot appear without waitOperands");
4278unsigned ExitDataOp::getNumDataOperands() {
4279 return getDataClauseOperands().size();
4282Value ExitDataOp::getDataOperand(
unsigned i) {
4283 unsigned numOptional = getIfCond() ? 1 : 0;
4284 numOptional += getAsyncOperand() ? 1 : 0;
4285 numOptional += getWaitDevnum() ? 1 : 0;
4286 return getOperand(getWaitOperands().size() + numOptional + i);
4291 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4294void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4296 assert(effectiveDeviceTypes.empty());
4297 assert(!getAsyncAttr());
4298 assert(!getAsyncOperand());
4300 setAsyncAttr(mlir::UnitAttr::get(context));
4303void ExitDataOp::addAsyncOperand(
4306 assert(effectiveDeviceTypes.empty());
4307 assert(!getAsyncAttr());
4308 assert(!getAsyncOperand());
4310 getAsyncOperandMutable().append(newValue);
4315 assert(effectiveDeviceTypes.empty());
4316 assert(!getWaitAttr());
4317 assert(getWaitOperands().empty());
4318 assert(!getWaitDevnum());
4320 setWaitAttr(mlir::UnitAttr::get(context));
4323void ExitDataOp::addWaitOperands(
4326 assert(effectiveDeviceTypes.empty());
4327 assert(!getWaitAttr());
4328 assert(getWaitOperands().empty());
4329 assert(!getWaitDevnum());
4334 getWaitDevnumMutable().append(newValues.front());
4335 newValues = newValues.drop_front();
4338 getWaitOperandsMutable().append(newValues);
4345LogicalResult acc::EnterDataOp::verify() {
4349 if (getDataClauseOperands().empty())
4350 return emitError(
"at least one operand must be present in dataOperands on "
4351 "the enter data operation");
4355 if (getAsyncOperand() && getAsync())
4356 return emitError(
"async attribute cannot appear with asyncOperand");
4360 if (!getWaitOperands().empty() && getWait())
4361 return emitError(
"wait attribute cannot appear with waitOperands");
4363 if (getWaitDevnum() && getWaitOperands().empty())
4364 return emitError(
"wait_devnum cannot appear without waitOperands");
4366 for (
mlir::Value operand : getDataClauseOperands())
4367 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4368 operand.getDefiningOp()))
4369 return emitError(
"expect data entry operation as defining op");
4374unsigned EnterDataOp::getNumDataOperands() {
4375 return getDataClauseOperands().size();
4378Value EnterDataOp::getDataOperand(
unsigned i) {
4379 unsigned numOptional = getIfCond() ? 1 : 0;
4380 numOptional += getAsyncOperand() ? 1 : 0;
4381 numOptional += getWaitDevnum() ? 1 : 0;
4382 return getOperand(getWaitOperands().size() + numOptional + i);
4387 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4390void EnterDataOp::addAsyncOnly(
4392 assert(effectiveDeviceTypes.empty());
4393 assert(!getAsyncAttr());
4394 assert(!getAsyncOperand());
4396 setAsyncAttr(mlir::UnitAttr::get(context));
4399void EnterDataOp::addAsyncOperand(
4402 assert(effectiveDeviceTypes.empty());
4403 assert(!getAsyncAttr());
4404 assert(!getAsyncOperand());
4406 getAsyncOperandMutable().append(newValue);
4409void EnterDataOp::addWaitOnly(
MLIRContext *context,
4411 assert(effectiveDeviceTypes.empty());
4412 assert(!getWaitAttr());
4413 assert(getWaitOperands().empty());
4414 assert(!getWaitDevnum());
4416 setWaitAttr(mlir::UnitAttr::get(context));
4419void EnterDataOp::addWaitOperands(
4422 assert(effectiveDeviceTypes.empty());
4423 assert(!getWaitAttr());
4424 assert(getWaitOperands().empty());
4425 assert(!getWaitDevnum());
4430 getWaitDevnumMutable().append(newValues.front());
4431 newValues = newValues.drop_front();
4434 getWaitOperandsMutable().append(newValues);
4441LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4447LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4453LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4460 if (
Value writeVal = op.getWriteOpVal()) {
4469LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4471LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4477AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4478 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4480 return dyn_cast<AtomicReadOp>(getSecondOp());
4483AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4484 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4486 return dyn_cast<AtomicWriteOp>(getSecondOp());
4489AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4490 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4492 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4495LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4501template <
typename Op>
4504 bool requireAtLeastOneOperand =
true) {
4505 if (operands.empty() && requireAtLeastOneOperand)
4508 "at least one operand must appear on the declare operation");
4511 if (isa<BlockArgument>(operand) ||
4512 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4513 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4514 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4515 operand.getDefiningOp()))
4517 "expect valid declare data entry operation or acc.getdeviceptr "
4521 assert(var &&
"declare operands can only be data entry operations which "
4524 std::optional<mlir::acc::DataClause> dataClauseOptional{
4526 assert(dataClauseOptional.has_value() &&
4527 "declare operands can only be data entry operations which must have "
4529 (
void)dataClauseOptional;
4535LogicalResult acc::DeclareEnterOp::verify() {
4543LogicalResult acc::DeclareExitOp::verify() {
4554LogicalResult acc::DeclareOp::verify() {
4563 acc::DeviceType dtype) {
4564 unsigned parallelism = 0;
4565 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4566 parallelism += op.hasWorker(dtype) ? 1 : 0;
4567 parallelism += op.hasVector(dtype) ? 1 : 0;
4568 parallelism += op.hasSeq(dtype) ? 1 : 0;
4572LogicalResult acc::RoutineOp::verify() {
4573 unsigned baseParallelism =
4576 if (baseParallelism > 1)
4577 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4578 "be present at the same time";
4580 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4582 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4583 if (dtype == acc::DeviceType::None)
4587 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4588 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4589 "be present at the same time for device_type `"
4590 << acc::stringifyDeviceType(dtype) <<
"`";
4597 mlir::ArrayAttr &bindIdName,
4598 mlir::ArrayAttr &bindStrName,
4599 mlir::ArrayAttr &deviceIdTypes,
4600 mlir::ArrayAttr &deviceStrTypes) {
4607 mlir::Attribute newAttr;
4608 bool isSymbolRefAttr;
4609 auto parseResult = parser.parseAttribute(newAttr);
4610 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4611 bindIdNameAttrs.push_back(symbolRefAttr);
4612 isSymbolRefAttr = true;
4613 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4614 bindStrNameAttrs.push_back(stringAttr);
4615 isSymbolRefAttr =
false;
4620 if (isSymbolRefAttr) {
4621 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4622 parser.getContext(), mlir::acc::DeviceType::None));
4624 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4625 parser.getContext(), mlir::acc::DeviceType::None));
4628 if (isSymbolRefAttr) {
4629 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4630 parser.parseRSquare())
4633 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4634 parser.parseRSquare())
4642 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4643 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4644 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4645 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4651 std::optional<mlir::ArrayAttr> bindIdName,
4652 std::optional<mlir::ArrayAttr> bindStrName,
4653 std::optional<mlir::ArrayAttr> deviceIdTypes,
4654 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4661 allBindNames.append(bindIdName->begin(), bindIdName->end());
4662 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4667 allBindNames.append(bindStrName->begin(), bindStrName->end());
4668 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4672 if (!allBindNames.empty())
4673 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4674 [&](
const auto &pair) {
4675 p << std::get<0>(pair);
4681 mlir::ArrayAttr &gang,
4682 mlir::ArrayAttr &gangDim,
4683 mlir::ArrayAttr &gangDimDeviceTypes) {
4686 gangDimDeviceTypeAttrs;
4687 bool needCommaBeforeOperands =
false;
4691 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4692 parser.
getContext(), mlir::acc::DeviceType::None));
4693 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4700 if (parser.parseAttribute(gangAttrs.emplace_back()))
4707 needCommaBeforeOperands =
true;
4710 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4714 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4715 parser.parseColon() ||
4716 parser.parseAttribute(gangDimAttrs.emplace_back()))
4718 if (succeeded(parser.parseOptionalLSquare())) {
4719 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4720 parser.parseRSquare())
4723 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4724 parser.getContext(), mlir::acc::DeviceType::None));
4730 if (
failed(parser.parseRParen()))
4733 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4734 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4735 gangDimDeviceTypes =
4736 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4742 std::optional<mlir::ArrayAttr> gang,
4743 std::optional<mlir::ArrayAttr> gangDim,
4744 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4747 gang->size() == 1) {
4748 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4749 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4761 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4762 [&](
const auto &pair) {
4763 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4764 p << std::get<0>(pair);
4772 mlir::ArrayAttr &deviceTypes) {
4776 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4777 parser.
getContext(), mlir::acc::DeviceType::None));
4778 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4785 if (parser.parseAttribute(attributes.emplace_back()))
4793 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4799 std::optional<mlir::ArrayAttr> deviceTypes) {
4802 auto deviceTypeAttr =
4803 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4804 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4813 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4819bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4821bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4825bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4827bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4831bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4833bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4837std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4838RoutineOp::getBindNameValue() {
4839 return getBindNameValue(mlir::acc::DeviceType::None);
4842std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4843RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4845 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4846 auto attr = (*getBindIdName())[*pos];
4847 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4848 assert(symbolRefAttr &&
"expected SymbolRef");
4849 return symbolRefAttr;
4854 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4855 auto attr = (*getBindStrName())[*pos];
4856 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4857 assert(stringAttr &&
"expected String");
4862 return std::nullopt;
4865bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4867bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4871std::optional<int64_t> RoutineOp::getGangDimValue() {
4872 return getGangDimValue(mlir::acc::DeviceType::None);
4875std::optional<int64_t>
4876RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4878 return std::nullopt;
4879 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4880 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4881 return intAttr.getInt();
4883 return std::nullopt;
4888 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4889 effectiveDeviceTypes));
4894 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4895 effectiveDeviceTypes));
4900 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4901 effectiveDeviceTypes));
4906 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4907 effectiveDeviceTypes));
4916 if (getGangDimAttr())
4917 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4918 if (getGangDimDeviceTypeAttr())
4919 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4921 assert(dimValues.size() == deviceTypes.size());
4923 if (effectiveDeviceTypes.empty()) {
4924 dimValues.push_back(
4925 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4926 deviceTypes.push_back(
4927 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4929 for (DeviceType dt : effectiveDeviceTypes) {
4930 dimValues.push_back(
4931 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4932 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4935 assert(dimValues.size() == deviceTypes.size());
4937 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4938 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4941void RoutineOp::addBindStrName(
MLIRContext *context,
4943 mlir::StringAttr val) {
4944 unsigned before = getBindStrNameDeviceTypeAttr()
4945 ? getBindStrNameDeviceTypeAttr().size()
4948 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4949 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4950 unsigned after = getBindStrNameDeviceTypeAttr().size();
4953 if (getBindStrNameAttr())
4954 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4955 for (
unsigned i = 0; i < after - before; ++i)
4956 vals.push_back(val);
4958 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4961void RoutineOp::addBindIDName(
MLIRContext *context,
4963 mlir::SymbolRefAttr val) {
4965 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4967 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4968 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4969 unsigned after = getBindIdNameDeviceTypeAttr().size();
4972 if (getBindIdNameAttr())
4973 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4974 for (
unsigned i = 0; i < after - before; ++i)
4975 vals.push_back(val);
4977 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4984LogicalResult acc::InitOp::verify() {
4985 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4986 return emitOpError(
"cannot be nested in a compute operation");
4990void acc::InitOp::addDeviceType(
MLIRContext *context,
4991 mlir::acc::DeviceType deviceType) {
4993 if (getDeviceTypesAttr())
4994 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4996 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4997 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
5004LogicalResult acc::ShutdownOp::verify() {
5005 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
5006 return emitOpError(
"cannot be nested in a compute operation");
5010void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
5011 mlir::acc::DeviceType deviceType) {
5013 if (getDeviceTypesAttr())
5014 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
5016 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
5017 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
5024LogicalResult acc::SetOp::verify() {
5025 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
5026 return emitOpError(
"cannot be nested in a compute operation");
5027 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
5028 return emitOpError(
"at least one default_async, device_num, or device_type "
5029 "operand must appear");
5037LogicalResult acc::UpdateOp::verify() {
5039 if (getDataClauseOperands().empty())
5040 return emitError(
"at least one value must be present in dataOperands");
5043 getAsyncOperandsDeviceTypeAttr(),
5048 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
5049 getWaitOperandsDeviceTypeAttr(),
"wait")))
5055 for (
mlir::Value operand : getDataClauseOperands())
5056 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
5057 operand.getDefiningOp()))
5058 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
5064unsigned UpdateOp::getNumDataOperands() {
5065 return getDataClauseOperands().size();
5068Value UpdateOp::getDataOperand(
unsigned i) {
5070 numOptional += getIfCond() ? 1 : 0;
5071 return getOperand(getWaitOperands().size() + numOptional + i);
5076 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
5079bool UpdateOp::hasAsyncOnly() {
5080 return hasAsyncOnly(mlir::acc::DeviceType::None);
5083bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5088 return getAsyncValue(mlir::acc::DeviceType::None);
5091mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
5101bool UpdateOp::hasWaitOnly() {
5102 return hasWaitOnly(mlir::acc::DeviceType::None);
5105bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5110 return getWaitValues(mlir::acc::DeviceType::None);
5114UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
5116 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
5117 getHasWaitDevnum(), deviceType);
5121 return getWaitDevnum(mlir::acc::DeviceType::None);
5124mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5126 getWaitOperandsSegments(), getHasWaitDevnum(),
5132 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5133 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5136void UpdateOp::addAsyncOperand(
5139 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5140 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5141 getAsyncOperandsMutable()));
5146 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5147 effectiveDeviceTypes));
5150void UpdateOp::addWaitOperands(
5155 if (getWaitOperandsSegments())
5156 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5158 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5159 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5160 getWaitOperandsMutable(), segments));
5161 setWaitOperandsSegments(segments);
5164 if (getHasWaitDevnumAttr())
5165 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5168 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5170 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5177LogicalResult acc::WaitOp::verify() {
5180 if (getAsyncOperand() && getAsync())
5181 return emitError(
"async attribute cannot appear with asyncOperand");
5183 if (getWaitDevnum() && getWaitOperands().empty())
5184 return emitError(
"wait_devnum cannot appear without waitOperands");
5189#define GET_OP_CLASSES
5190#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5192#define GET_ATTRDEF_CLASSES
5193#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5195#define GET_TYPEDEF_CLASSES
5196#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5207 .Case<ACC_DATA_ENTRY_OPS>(
5208 [&](
auto entry) {
return entry.getVarPtr(); })
5209 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5210 [&](
auto exit) {
return exit.getVarPtr(); })
5228 [&](
auto entry) {
return entry.getVarType(); })
5229 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5230 [&](
auto exit) {
return exit.getVarType(); })
5240 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5241 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5251 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5260 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5270 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5272 dataClause.getBounds().begin(), dataClause.getBounds().end());
5284 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5286 dataClause.getAsyncOperands().begin(),
5287 dataClause.getAsyncOperands().end());
5298 return dataClause.getAsyncOperandsDeviceTypeAttr();
5306 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5313 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5320std::optional<mlir::acc::DataClause>
5325 .Case<ACC_DATA_ENTRY_OPS>(
5326 [&](
auto entry) {
return entry.getDataClause(); })
5334 [&](
auto entry) {
return entry.getImplicit(); })
5343 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5345 return dataOperands;
5353 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5355 return dataOperands;
5362 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasAnyGangWorkerVectorForDeviceType(std::optional< mlir::ArrayAttr > numGangsDeviceType, mlir::Operation::operand_range numGangs, std::optional< llvm::ArrayRef< int32_t > > numGangsSegments, std::optional< mlir::ArrayAttr > numWorkersDeviceType, mlir::Operation::operand_range numWorkers, std::optional< mlir::ArrayAttr > vectorLengthDeviceType, mlir::Operation::operand_range vectorLength, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region ®ion, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > ®ions)
Generic helper for single-region OpenACC ops that execute their body once and then continue after the...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.