26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/LogicalResult.h"
34#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
38#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
42static bool isScalarLikeType(
Type type) {
50 if (!varName.empty()) {
51 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
57struct MemRefPointerLikeModel
58 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
60 return cast<T>(pointer).getElementType();
63 mlir::acc::VariableTypeCategory
66 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
67 return mappableTy.getTypeCategory(varPtr);
69 auto memrefTy = cast<T>(pointer);
70 if (!memrefTy.hasRank()) {
73 return mlir::acc::VariableTypeCategory::uncategorized;
76 if (memrefTy.getRank() == 0) {
77 if (isScalarLikeType(memrefTy.getElementType())) {
78 return mlir::acc::VariableTypeCategory::scalar;
82 return mlir::acc::VariableTypeCategory::uncategorized;
86 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
87 return mlir::acc::VariableTypeCategory::array;
90 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
91 StringRef varName, Type varType, Value originalVar,
92 bool &needsFree)
const {
93 auto memrefTy = cast<MemRefType>(pointer);
97 if (memrefTy.hasStaticShape()) {
99 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
100 attachVarNameAttr(allocaOp, builder, varName);
101 return allocaOp.getResult();
106 if (originalVar && originalVar.
getType() == memrefTy &&
107 memrefTy.hasRank()) {
108 SmallVector<Value> dynamicSizes;
109 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
110 if (memrefTy.isDynamicDim(i)) {
114 memref::DimOp::create(builder, loc, originalVar, indexValue);
115 dynamicSizes.push_back(dimSize);
122 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
123 attachVarNameAttr(allocOp, builder, varName);
124 return allocOp.getResult();
131 bool genFree(Type pointer, OpBuilder &builder, Location loc,
133 Type varType)
const {
136 Value valueToInspect = allocRes ? allocRes : memrefValue;
139 Value currentValue = valueToInspect;
140 Operation *originalAlloc =
nullptr;
144 while (currentValue) {
147 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
148 originalAlloc = definingOp;
153 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
154 currentValue = castOp.getSource();
159 if (
auto reinterpretCastOp =
160 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
161 currentValue = reinterpretCastOp.getSource();
173 if (isa<memref::AllocaOp>(originalAlloc)) {
177 if (isa<memref::AllocOp>(originalAlloc)) {
179 memref::DeallocOp::create(builder, loc, memrefValue);
188 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
192 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
193 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
199 if (destMemref && srcMemref &&
200 destMemref.getType().getElementType() ==
201 srcMemref.getType().getElementType() &&
202 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
203 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
210 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
212 Type valueType)
const {
217 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
221 auto memrefTy = memrefValue.
getType();
224 if (memrefTy.getRank() != 0)
227 return memref::LoadOp::create(builder, loc, memrefValue);
230 bool genStore(Type pointer, OpBuilder &builder, Location loc,
236 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
240 auto memrefTy = memrefValue.getType();
243 if (memrefTy.getRank() != 0)
246 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
250 Value
genCast(Type, OpBuilder &builder, Location loc, Value value,
251 Type resultType)
const {
252 if (value.
getType() == resultType)
255 if (isa<BaseMemRefType>(value.
getType()) &&
256 isa<BaseMemRefType>(resultType)) {
259 return memref::CastOp::create(builder, loc, resultType, value);
260 if (memref::MemorySpaceCastOp::areCastCompatible(
262 return memref::MemorySpaceCastOp::create(builder, loc, resultType,
269 if (
auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
270 if (!isa<BaseMemRefType>(resPtrLike))
271 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
273 if (
auto valPtrLike = dyn_cast<PointerLikeType>(value.
getType()))
274 if (!isa<BaseMemRefType>(valPtrLike))
275 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
281 bool isDeviceData(Type pointer, Value var)
const {
282 auto memrefTy = cast<T>(pointer);
283 Attribute memSpace = memrefTy.getMemorySpace();
284 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
288struct LLVMPointerPointerLikeModel
289 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
290 LLVM::LLVMPointerType> {
293 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
295 Type valueType)
const {
300 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
303 bool genStore(Type pointer, OpBuilder &builder, Location loc,
305 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
309 Value
genCast(Type, OpBuilder &builder, Location loc, Value value,
310 Type resultType)
const {
311 if (value.
getType() == resultType)
314 auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(value.
getType());
315 auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
316 if (srcPtrTy && dstPtrTy) {
317 if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
318 return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
322 if (srcPtrTy && isa<IntegerType>(resultType))
323 return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
326 Value intVal = value;
327 if (isa<IndexType>(value.
getType()))
328 intVal = arith::IndexCastUIOp::create(builder, loc,
330 if (isa<IntegerType>(intVal.
getType()))
331 return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
334 if (
auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
335 if (!isa<LLVM::LLVMPointerType>(resPtrLike))
336 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
338 if (
auto valPtrLike = dyn_cast<PointerLikeType>(value.
getType()))
339 if (!isa<LLVM::LLVMPointerType>(valPtrLike))
340 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
343 return UnrealizedConversionCastOp::create(builder, loc,
349struct MemrefAddressOfGlobalModel
350 :
public AddressOfGlobalOpInterface::ExternalModel<
351 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
352 SymbolRefAttr getSymbol(Operation *op)
const {
353 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
354 return getGlobalOp.getNameAttr();
358struct MemrefGlobalVariableModel
359 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
361 bool isConstant(Operation *op)
const {
362 auto globalOp = cast<memref::GlobalOp>(op);
363 return globalOp.getConstant();
366 Region *getInitRegion(Operation *op)
const {
371 bool isDeviceData(Operation *op)
const {
372 auto globalOp = cast<memref::GlobalOp>(op);
373 Attribute memSpace = globalOp.getType().getMemorySpace();
374 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
378struct GPULaunchOffloadRegionModel
379 :
public acc::OffloadRegionOpInterface::ExternalModel<
380 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
381 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
382 return cast<gpu::LaunchOp>(op).getBody();
390mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
391 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
394 if (existingDeviceTypes)
395 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
397 if (newDeviceTypes.empty())
398 deviceTypes.push_back(
399 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
401 for (DeviceType dt : newDeviceTypes)
402 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
404 return mlir::ArrayAttr::get(context, deviceTypes);
413mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
414 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
419 if (existingDeviceTypes)
420 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
422 if (newDeviceTypes.empty()) {
423 argCollection.
append(arguments);
424 segments.push_back(arguments.size());
425 deviceTypes.push_back(
426 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
429 for (DeviceType dt : newDeviceTypes) {
430 argCollection.
append(arguments);
431 segments.push_back(arguments.size());
432 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
435 return mlir::ArrayAttr::get(context, deviceTypes);
439mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
440 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
444 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
445 newDeviceTypes, arguments,
446 argCollection, segments);
454void OpenACCDialect::initialize() {
457#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
460#define GET_ATTRDEF_LIST
461#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
464#define GET_TYPEDEF_LIST
465#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
471 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
473 UnrankedMemRefType::attachInterface<
474 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
475 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
479 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
481 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
482 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
519void ParallelOp::getSuccessorRegions(
549void HostDataOp::getSuccessorRegions(
564 if (getUnstructured()) {
597 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
601 mlir::acc::DeviceType deviceType) {
605 for (
auto attr : *arrayAttr) {
606 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
607 if (deviceTypeAttr.getValue() == deviceType)
615 std::optional<mlir::ArrayAttr> deviceTypes) {
620 llvm::interleaveComma(*deviceTypes, p,
626 mlir::acc::DeviceType deviceType) {
627 unsigned segmentIdx = 0;
628 for (
auto attr : segments) {
629 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
630 if (deviceTypeAttr.getValue() == deviceType)
631 return std::make_optional(segmentIdx);
641 mlir::acc::DeviceType deviceType) {
643 return range.take_front(0);
644 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
645 int32_t nbOperandsBefore = 0;
646 for (
unsigned i = 0; i < *pos; ++i)
647 nbOperandsBefore += (*segments)[i];
648 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
650 return range.take_front(0);
657 std::optional<mlir::ArrayAttr> hasWaitDevnum,
658 mlir::acc::DeviceType deviceType) {
661 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
662 if (hasWaitDevnum->getValue()[*pos])
673 std::optional<mlir::ArrayAttr> hasWaitDevnum,
674 mlir::acc::DeviceType deviceType) {
679 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
680 if (hasWaitDevnum && *hasWaitDevnum) {
681 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
682 if (boolAttr.getValue())
683 return range.drop_front(1);
689template <
typename Op>
691 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
693 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
698 op.hasAsyncOnly(dtype))
700 "asyncOnly attribute cannot appear with asyncOperand");
705 op.hasWaitOnly(dtype))
706 return op.
emitError(
"wait attribute cannot appear with waitOperands");
711template <
typename Op>
714 return op.
emitError(
"must have var operand");
717 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
718 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
719 return op.
emitError(
"var must be mappable or pointer-like");
722 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
723 op.getVarType() == op.getVar().getType())
724 return op.
emitError(
"varType must capture the element type of var");
729template <
typename Op>
731 if (op.getVar().getType() != op.getAccVar().getType())
732 return op.
emitError(
"input and output types must match");
737template <
typename Op>
739 if (op.getModifiers() != acc::DataClauseModifier::none)
740 return op.
emitError(
"no data clause modifiers are allowed");
744template <
typename Op>
747 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
749 "invalid data clause modifiers: " +
750 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
755template <
typename OpT,
typename RecipeOpT>
756static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
761 !std::is_same_v<OpT, acc::ReductionOp>)
764 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
766 return op->emitOpError() <<
"recipe expected for " << operandName;
771 return op->emitOpError()
772 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
773 << operandName <<
" declaration";
794 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
815 if (failed(parser.
parseType(accVarType)))
825 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
837 mlir::TypeAttr &varTypeAttr) {
838 if (failed(parser.
parseType(varPtrType)))
849 varTypeAttr = mlir::TypeAttr::get(varType);
854 if (
auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
855 Type elementType = ptrTy.getElementType();
858 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
860 varTypeAttr = mlir::TypeAttr::get(varPtrType);
868 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
876 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
877 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
881 if (!typeToCheckAgainst)
882 typeToCheckAgainst = varPtrType;
883 if (typeToCheckAgainst != varType) {
891 mlir::SymbolRefAttr &recipeAttr) {
898 mlir::SymbolRefAttr recipeAttr) {
/// Verifier for DataBoundsOp. It is an error for a bounds op to carry neither
/// an extent nor an upperbound operand.
905LogicalResult acc::DataBoundsOp::verify() {
906 auto extent = getExtent();
907 auto upperbound = getUpperbound();
// At least one of the two size-defining operands must be present.
908 if (!extent && !upperbound)
909 return emitError(
"expected extent or upperbound.");
916LogicalResult acc::PrivateOp::verify() {
919 "data clause associated with private operation must match its intent");
933LogicalResult acc::FirstprivateOp::verify() {
935 return emitError(
"data clause associated with firstprivate operation must "
942 *
this,
"firstprivate")))
950LogicalResult acc::ReductionOp::verify() {
952 return emitError(
"data clause associated with reduction operation must "
959 *
this,
"reduction")))
967LogicalResult acc::DevicePtrOp::verify() {
969 return emitError(
"data clause associated with deviceptr operation must "
983LogicalResult acc::PresentOp::verify() {
986 "data clause associated with present operation must match its intent");
999LogicalResult acc::CopyinOp::verify() {
1001 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
1006 "data clause associated with copyin operation must match its intent"
1007 " or specify original clause this operation was decomposed from");
1013 acc::DataClauseModifier::always |
1014 acc::DataClauseModifier::capture)))
/// Returns true when this copyin has readonly semantics, i.e. either its data
/// clause is `acc_copyin_readonly` or its modifiers include `readonly`.
1019bool acc::CopyinOp::isCopyinReadonly() {
1020 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
1021 acc::bitEnumContainsAny(getModifiers(),
1022 acc::DataClauseModifier::readonly);
1028LogicalResult acc::CreateOp::verify() {
1035 "data clause associated with create operation must match its intent"
1036 " or specify original clause this operation was decomposed from");
1044 acc::DataClauseModifier::always |
1045 acc::DataClauseModifier::capture)))
1050bool acc::CreateOp::isCreateZero() {
1052 return getDataClause() == acc::DataClause::acc_create_zero ||
1054 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1060LogicalResult acc::NoCreateOp::verify() {
1062 return emitError(
"data clause associated with no_create operation must "
1063 "match its intent");
1076LogicalResult acc::AttachOp::verify() {
1079 "data clause associated with attach operation must match its intent");
1093LogicalResult acc::DeclareDeviceResidentOp::verify() {
1094 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1095 return emitError(
"data clause associated with device_resident operation "
1096 "must match its intent");
1110LogicalResult acc::DeclareLinkOp::verify() {
1113 "data clause associated with link operation must match its intent");
1126LogicalResult acc::CopyoutOp::verify() {
1133 "data clause associated with copyout operation must match its intent"
1134 " or specify original clause this operation was decomposed from");
1136 return emitError(
"must have both host and device pointers");
1142 acc::DataClauseModifier::always |
1143 acc::DataClauseModifier::capture)))
/// Returns true when this copyout has zero semantics, i.e. either its data
/// clause is `acc_copyout_zero` or its modifiers include `zero`.
1148bool acc::CopyoutOp::isCopyoutZero() {
1149 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1150 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1156LogicalResult acc::DeleteOp::verify() {
1165 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1168 "data clause associated with delete operation must match its intent"
1169 " or specify original clause this operation was decomposed from");
1171 return emitError(
"must have device pointer");
1175 acc::DataClauseModifier::readonly |
1176 acc::DataClauseModifier::always |
1177 acc::DataClauseModifier::capture)))
1185LogicalResult acc::DetachOp::verify() {
1190 "data clause associated with detach operation must match its intent"
1191 " or specify original clause this operation was decomposed from");
1193 return emitError(
"must have device pointer");
1202LogicalResult acc::UpdateHostOp::verify() {
1207 "data clause associated with host operation must match its intent"
1208 " or specify original clause this operation was decomposed from");
1210 return emitError(
"must have both host and device pointers");
1223LogicalResult acc::UpdateDeviceOp::verify() {
1227 "data clause associated with device operation must match its intent"
1228 " or specify original clause this operation was decomposed from");
1241LogicalResult acc::UseDeviceOp::verify() {
1245 "data clause associated with use_device operation must match its intent"
1246 " or specify original clause this operation was decomposed from");
1259LogicalResult acc::CacheOp::verify() {
1264 "data clause associated with cache operation must match its intent"
1265 " or specify original clause this operation was decomposed from");
/// Returns true when this cache op has readonly semantics, i.e. either its
/// data clause is `acc_cache_readonly` or its modifiers include `readonly`.
1275bool acc::CacheOp::isCacheReadonly() {
1276 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1277 acc::bitEnumContainsAny(getModifiers(),
1278 acc::DataClauseModifier::readonly);
1294template <
typename EffectTy>
1299 for (
unsigned i = 0, e = operand.
size(); i < e; ++i)
1300 effects.emplace_back(EffectTy::get(), &operand[i]);
1304template <
typename EffectTy>
1309 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(
result));
1313void acc::PrivateOp::getEffects(
1327void acc::FirstprivateOp::getEffects(
1341void acc::ReductionOp::getEffects(
1355void acc::DevicePtrOp::getEffects(
1364void acc::PresentOp::getEffects(
1375void acc::CopyinOp::getEffects(
1388void acc::CreateOp::getEffects(
1401void acc::NoCreateOp::getEffects(
1412void acc::AttachOp::getEffects(
1425void acc::GetDevicePtrOp::getEffects(
1434void acc::UpdateDeviceOp::getEffects(
1444void acc::UseDeviceOp::getEffects(
1453void acc::DeclareDeviceResidentOp::getEffects(
1464void acc::DeclareLinkOp::getEffects(
1475void acc::CacheOp::getEffects(
1480void acc::CopyoutOp::getEffects(
1493void acc::DeleteOp::getEffects(
1505void acc::DetachOp::getEffects(
1517void acc::UpdateHostOp::getEffects(
1529template <
typename StructureOp>
1531 unsigned nRegions = 1) {
1534 for (
unsigned i = 0; i < nRegions; ++i)
1537 for (
Region *region : regions)
1548template <
typename OpTy>
1550 using OpRewritePattern<OpTy>::OpRewritePattern;
1552 LogicalResult matchAndRewrite(OpTy op,
1553 PatternRewriter &rewriter)
const override {
1555 Value ifCond = op.getIfCond();
1559 IntegerAttr constAttr;
1562 if (constAttr.getInt())
1563 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1575 assert(region.
hasOneBlock() &&
"expected single-block region");
1587template <
typename OpTy>
1588struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1589 using OpRewritePattern<OpTy>::OpRewritePattern;
1591 LogicalResult matchAndRewrite(OpTy op,
1592 PatternRewriter &rewriter)
const override {
1594 Value ifCond = op.getIfCond();
1598 IntegerAttr constAttr;
1601 if (constAttr.getInt())
1602 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1630 for (
Value bound : bounds) {
1631 argTypes.push_back(bound.getType());
1632 argLocs.push_back(loc);
1639 Value privatizedValue;
1645 if (isa<MappableType>(varType)) {
1646 auto mappableTy = cast<MappableType>(varType);
1647 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1648 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1649 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1650 privatizedValue = mappableTy.generatePrivateInit(
1651 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1652 if (!privatizedValue)
1655 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1656 auto pointerLikeTy = cast<PointerLikeType>(varType);
1658 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1659 blockArgVar, needsFree);
1660 if (!privatizedValue)
1665 acc::YieldOp::create(builder, loc, privatizedValue);
1682 for (
Value bound : bounds) {
1683 copyArgTypes.push_back(bound.getType());
1684 copyArgLocs.push_back(loc);
1694 if (isa<MappableType>(varType)) {
1695 auto mappableTy = cast<MappableType>(varType);
1698 if (!mappableTy.generateCopy(
1703 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1704 auto pointerLikeTy = cast<PointerLikeType>(varType);
1705 if (!pointerLikeTy.genCopy(
1712 acc::TerminatorOp::create(builder, loc);
1729 for (
Value bound : bounds) {
1730 destroyArgTypes.push_back(bound.getType());
1731 destroyArgLocs.push_back(loc);
1735 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1739 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1740 if (isa<MappableType>(varType)) {
1741 auto mappableTy = cast<MappableType>(varType);
1742 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1746 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1747 auto pointerLikeTy = cast<PointerLikeType>(varType);
1748 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1752 acc::TerminatorOp::create(builder, loc);
1763 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1765 if (optional && region.
empty())
1769 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1773 return op->
emitOpError() <<
"expects " << regionName
1776 << regionType <<
" type";
1779 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1780 if (yieldOp.getOperands().size() != 1 ||
1781 yieldOp.getOperands().getTypes()[0] != type)
1782 return op->
emitOpError() <<
"expects " << regionName
1784 "yield a value of the "
1785 << regionType <<
" type";
1791LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1793 "privatization",
"init",
getType(),
1797 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1803std::optional<PrivateRecipeOp>
1805 StringRef recipeName,
Value hostVar,
1810 bool isMappable = isa<MappableType>(varType);
1811 bool isPointerLike = isa<PointerLikeType>(varType);
1814 if (!isMappable && !isPointerLike)
1815 return std::nullopt;
1820 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1823 bool needsFree =
false;
1825 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1826 varName, bounds, needsFree, varInfo))) {
1828 return std::nullopt;
1835 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1836 Value allocRes = yieldOp.getOperand(0);
1838 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1839 varType, allocRes, bounds, varInfo))) {
1841 return std::nullopt;
1848std::optional<PrivateRecipeOp>
1850 StringRef recipeName,
1851 FirstprivateRecipeOp firstprivRecipe) {
1854 auto varType = firstprivRecipe.getType();
1855 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1859 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1862 if (!firstprivRecipe.getDestroyRegion().empty()) {
1864 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1874LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1876 "privatization",
"init",
getType(),
1880 if (getCopyRegion().empty())
1881 return emitOpError() <<
"expects non-empty copy region";
1886 return emitOpError() <<
"expects copy region with two arguments of the "
1887 "privatization type";
1889 if (getDestroyRegion().empty())
1893 "privatization",
"destroy",
1900std::optional<FirstprivateRecipeOp>
1902 StringRef recipeName,
Value hostVar,
1907 bool isMappable = isa<MappableType>(varType);
1908 bool isPointerLike = isa<PointerLikeType>(varType);
1911 if (!isMappable && !isPointerLike)
1912 return std::nullopt;
1917 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1920 bool needsFree =
false;
1925 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1926 varName, bounds, needsFree, varInfo))) {
1928 return std::nullopt;
1932 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1933 bounds, varInfo))) {
1935 return std::nullopt;
1942 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1943 Value allocRes = yieldOp.getOperand(0);
1945 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1946 varType, allocRes, bounds, varInfo))) {
1948 return std::nullopt;
1959LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1965 if (getCombinerRegion().empty())
1966 return emitOpError() <<
"expects non-empty combiner region";
1968 Block &reductionBlock = getCombinerRegion().
front();
1972 return emitOpError() <<
"expects combiner region with the first two "
1973 <<
"arguments of the reduction type";
1975 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1976 if (yieldOp.getOperands().size() != 1 ||
1977 yieldOp.getOperands().getTypes()[0] !=
getType())
1978 return emitOpError() <<
"expects combiner region to yield a value "
1979 "of the reduction type";
1990template <
typename Op>
1994 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1995 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1996 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1997 operand.getDefiningOp()))
1999 "expect data entry/exit operation or acc.getdeviceptr "
2004template <
typename OpT,
typename RecipeOpT>
2007 llvm::StringRef operandName) {
2010 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2012 <<
"expected " << operandName <<
" as defining op";
2013 if (!set.insert(operand).second)
2015 << operandName <<
" operand appears more than once";
/// Returns the total number of data operands on this parallel op: the sum of
/// the reduction, private, firstprivate, and data-clause operand groups.
2020unsigned ParallelOp::getNumDataOperands() {
2021 return getReductionOperands().size() + getPrivateOperands().size() +
2022 getFirstprivateOperands().size() + getDataClauseOperands().size();
2025Value ParallelOp::getDataOperand(
unsigned i) {
2027 numOptional += getNumGangs().size();
2028 numOptional += getNumWorkers().size();
2029 numOptional += getVectorLength().size();
2030 numOptional += getIfCond() ? 1 : 0;
2031 numOptional += getSelfCond() ? 1 : 0;
2032 return getOperand(getWaitOperands().size() + numOptional + i);
2035template <
typename Op>
2038 llvm::StringRef keyword) {
2039 if (!operands.empty() &&
2040 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
2041 return op.
emitOpError() << keyword <<
" operands count must match "
2042 << keyword <<
" device_type count";
2046template <
typename Op>
2049 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2050 std::size_t numOperandsInSegments = 0;
2051 std::size_t nbOfSegments = 0;
2054 for (
auto segCount : segments.
asArrayRef()) {
2055 if (maxInSegment != 0 && segCount > maxInSegment)
2056 return op.
emitOpError() << keyword <<
" expects a maximum of "
2057 << maxInSegment <<
" values per segment";
2058 numOperandsInSegments += segCount;
2063 if ((numOperandsInSegments != operands.size()) ||
2064 (!deviceTypes && !operands.empty()))
2066 << keyword <<
" operand count does not match count in segments";
2067 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2069 << keyword <<
" segment count does not match device_type count";
2073LogicalResult acc::ParallelOp::verify() {
2075 mlir::acc::PrivateRecipeOp>(
2076 *
this, getPrivateOperands(),
"private")))
2079 mlir::acc::FirstprivateRecipeOp>(
2080 *
this, getFirstprivateOperands(),
"firstprivate")))
2083 mlir::acc::ReductionRecipeOp>(
2084 *
this, getReductionOperands(),
"reduction")))
2088 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2089 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2093 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2094 getWaitOperandsDeviceTypeAttr(),
"wait")))
2098 getNumWorkersDeviceTypeAttr(),
2103 getVectorLengthDeviceTypeAttr(),
2108 getAsyncOperandsDeviceTypeAttr(),
2121 mlir::acc::DeviceType deviceType) {
2124 if (
auto pos =
findSegment(*arrayAttr, deviceType))
2129bool acc::ParallelOp::hasAsyncOnly() {
2130 return hasAsyncOnly(mlir::acc::DeviceType::None);
2133bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2138 return getAsyncValue(mlir::acc::DeviceType::None);
2141mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2146mlir::Value acc::ParallelOp::getNumWorkersValue() {
2147 return getNumWorkersValue(mlir::acc::DeviceType::None);
2151acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2156mlir::Value acc::ParallelOp::getVectorLengthValue() {
2157 return getVectorLengthValue(mlir::acc::DeviceType::None);
2161acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2163 getVectorLength(), deviceType);
2167 return getNumGangsValues(mlir::acc::DeviceType::None);
2171ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2173 getNumGangsSegments(), deviceType);
2176bool acc::ParallelOp::hasWaitOnly() {
2177 return hasWaitOnly(mlir::acc::DeviceType::None);
2180bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2185 return getWaitValues(mlir::acc::DeviceType::None);
2189ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2191 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2192 getHasWaitDevnum(), deviceType);
2196 return getWaitDevnum(mlir::acc::DeviceType::None);
2199mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2201 getWaitOperandsSegments(), getHasWaitDevnum(),
2216 odsBuilder, odsState, asyncOperands,
nullptr,
2217 nullptr, waitOperands,
nullptr,
2219 nullptr, numGangs,
nullptr,
2220 nullptr, numWorkers,
2221 nullptr, vectorLength,
2222 nullptr, ifCond, selfCond,
2223 nullptr, reductionOperands, gangPrivateOperands,
2224 gangFirstPrivateOperands, dataClauseOperands,
2228void acc::ParallelOp::addNumWorkersOperand(
2231 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2232 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2233 getNumWorkersMutable()));
2235void acc::ParallelOp::addVectorLengthOperand(
2238 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2239 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2240 getVectorLengthMutable()));
2243void acc::ParallelOp::addAsyncOnly(
2245 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2246 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2249void acc::ParallelOp::addAsyncOperand(
2252 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2253 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2254 getAsyncOperandsMutable()));
2257void acc::ParallelOp::addNumGangsOperands(
2261 if (getNumGangsSegments())
2262 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2264 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2265 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2266 getNumGangsMutable(), segments));
2268 setNumGangsSegments(segments);
2270void acc::ParallelOp::addWaitOnly(
2272 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2273 effectiveDeviceTypes));
2275void acc::ParallelOp::addWaitOperands(
2280 if (getWaitOperandsSegments())
2281 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2283 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2284 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2285 getWaitOperandsMutable(), segments));
2286 setWaitOperandsSegments(segments);
2289 if (getHasWaitDevnumAttr())
2290 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2293 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2295 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2298void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2299 mlir::acc::PrivateOp op,
2300 mlir::acc::PrivateRecipeOp recipe) {
2301 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2302 getPrivateOperandsMutable().append(op.getResult());
2305void acc::ParallelOp::addFirstPrivatization(
2306 MLIRContext *context, mlir::acc::FirstprivateOp op,
2307 mlir::acc::FirstprivateRecipeOp recipe) {
2308 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2309 getFirstprivateOperandsMutable().append(op.getResult());
2312void acc::ParallelOp::addReduction(
MLIRContext *context,
2313 mlir::acc::ReductionOp op,
2314 mlir::acc::ReductionRecipeOp recipe) {
2315 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2316 getReductionOperandsMutable().append(op.getResult());
2331 int32_t crtOperandsSize = operands.size();
2334 if (parser.parseOperand(operands.emplace_back()) ||
2335 parser.parseColonType(types.emplace_back()))
2340 seg.push_back(operands.size() - crtOperandsSize);
2350 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2351 parser.
getContext(), mlir::acc::DeviceType::None));
2357 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2364 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2365 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2366 p <<
" [" << attr <<
"]";
2371 std::optional<mlir::ArrayAttr> deviceTypes,
2372 std::optional<mlir::DenseI32ArrayAttr> segments) {
2374 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2376 llvm::interleaveComma(
2377 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2378 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2398 int32_t crtOperandsSize = operands.size();
2402 if (parser.parseOperand(operands.emplace_back()) ||
2403 parser.parseColonType(types.emplace_back()))
2409 seg.push_back(operands.size() - crtOperandsSize);
2419 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2420 parser.
getContext(), mlir::acc::DeviceType::None));
2426 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2435 std::optional<mlir::DenseI32ArrayAttr> segments) {
2437 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2439 llvm::interleaveComma(
2440 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2441 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2454 mlir::ArrayAttr &keywordOnly) {
2458 bool needCommaBeforeOperands =
false;
2462 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2463 parser.
getContext(), mlir::acc::DeviceType::None));
2464 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2471 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2478 needCommaBeforeOperands =
true;
2481 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2488 int32_t crtOperandsSize = operands.size();
2500 if (parser.parseOperand(operands.emplace_back()) ||
2501 parser.parseColonType(types.emplace_back()))
2507 seg.push_back(operands.size() - crtOperandsSize);
2517 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2518 parser.
getContext(), mlir::acc::DeviceType::None));
2525 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2526 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2528 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2536 if (attrs->size() != 1)
2538 if (
auto deviceTypeAttr =
2539 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2540 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2546 std::optional<mlir::ArrayAttr> deviceTypes,
2547 std::optional<mlir::DenseI32ArrayAttr> segments,
2548 std::optional<mlir::ArrayAttr> hasDevNum,
2549 std::optional<mlir::ArrayAttr> keywordOnly) {
2562 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2564 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2565 if (boolAttr && boolAttr.getValue())
2567 llvm::interleaveComma(
2568 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2569 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2586 if (parser.parseOperand(operands.emplace_back()) ||
2587 parser.parseColonType(types.emplace_back()))
2589 if (succeeded(parser.parseOptionalLSquare())) {
2590 if (parser.parseAttribute(attributes.emplace_back()) ||
2591 parser.parseRSquare())
2594 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2595 parser.getContext(), mlir::acc::DeviceType::None));
2602 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2609 std::optional<mlir::ArrayAttr> deviceTypes) {
2612 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2613 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2622 mlir::ArrayAttr &keywordOnlyDeviceType) {
2625 bool needCommaBeforeOperands =
false;
2629 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2630 parser.
getContext(), mlir::acc::DeviceType::None));
2631 keywordOnlyDeviceType =
2632 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2640 if (parser.parseAttribute(
2641 keywordOnlyDeviceTypeAttributes.emplace_back()))
2648 needCommaBeforeOperands =
true;
2651 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2656 if (parser.parseOperand(operands.emplace_back()) ||
2657 parser.parseColonType(types.emplace_back()))
2659 if (succeeded(parser.parseOptionalLSquare())) {
2660 if (parser.parseAttribute(attributes.emplace_back()) ||
2661 parser.parseRSquare())
2664 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2665 parser.getContext(), mlir::acc::DeviceType::None));
2671 if (
failed(parser.parseRParen()))
2676 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2683 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2685 if (operands.begin() == operands.end() &&
2701 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2702 mlir::Type &operandType, mlir::UnitAttr &attr) {
2705 attr = mlir::UnitAttr::get(parser.
getContext());
2715 if (failed(parser.
parseType(operandType)))
2725 std::optional<mlir::Value> operand,
2727 mlir::UnitAttr attr) {
2744 attr = mlir::UnitAttr::get(parser.
getContext());
2749 if (parser.parseOperand(operands.emplace_back()))
2757 if (parser.parseType(types.emplace_back()))
2772 mlir::UnitAttr attr) {
2777 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2779 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2785 mlir::acc::CombinedConstructsTypeAttr &attr) {
2787 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2788 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2790 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2791 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2793 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2794 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2797 "expected compute construct name");
2805 mlir::acc::CombinedConstructsTypeAttr attr) {
2807 switch (attr.getValue()) {
2808 case mlir::acc::CombinedConstructsType::KernelsLoop:
2811 case mlir::acc::CombinedConstructsType::ParallelLoop:
2814 case mlir::acc::CombinedConstructsType::SerialLoop:
2825unsigned SerialOp::getNumDataOperands() {
2826 return getReductionOperands().size() + getPrivateOperands().size() +
2827 getFirstprivateOperands().size() + getDataClauseOperands().size();
2830Value SerialOp::getDataOperand(
unsigned i) {
2832 numOptional += getIfCond() ? 1 : 0;
2833 numOptional += getSelfCond() ? 1 : 0;
2834 return getOperand(getWaitOperands().size() + numOptional + i);
2837bool acc::SerialOp::hasAsyncOnly() {
2838 return hasAsyncOnly(mlir::acc::DeviceType::None);
2841bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2846 return getAsyncValue(mlir::acc::DeviceType::None);
2849mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2854bool acc::SerialOp::hasWaitOnly() {
2855 return hasWaitOnly(mlir::acc::DeviceType::None);
2858bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2863 return getWaitValues(mlir::acc::DeviceType::None);
2867SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2869 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2870 getHasWaitDevnum(), deviceType);
2874 return getWaitDevnum(mlir::acc::DeviceType::None);
2877mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2879 getWaitOperandsSegments(), getHasWaitDevnum(),
2883LogicalResult acc::SerialOp::verify() {
2885 mlir::acc::PrivateRecipeOp>(
2886 *
this, getPrivateOperands(),
"private")))
2889 mlir::acc::FirstprivateRecipeOp>(
2890 *
this, getFirstprivateOperands(),
"firstprivate")))
2893 mlir::acc::ReductionRecipeOp>(
2894 *
this, getReductionOperands(),
"reduction")))
2898 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2899 getWaitOperandsDeviceTypeAttr(),
"wait")))
2903 getAsyncOperandsDeviceTypeAttr(),
2913void acc::SerialOp::addAsyncOnly(
2915 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2916 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2919void acc::SerialOp::addAsyncOperand(
2922 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2923 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2924 getAsyncOperandsMutable()));
2927void acc::SerialOp::addWaitOnly(
2929 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2930 effectiveDeviceTypes));
2932void acc::SerialOp::addWaitOperands(
2937 if (getWaitOperandsSegments())
2938 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2940 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2941 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2942 getWaitOperandsMutable(), segments));
2943 setWaitOperandsSegments(segments);
2946 if (getHasWaitDevnumAttr())
2947 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2950 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2952 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2955void acc::SerialOp::addPrivatization(
MLIRContext *context,
2956 mlir::acc::PrivateOp op,
2957 mlir::acc::PrivateRecipeOp recipe) {
2958 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2959 getPrivateOperandsMutable().append(op.getResult());
2962void acc::SerialOp::addFirstPrivatization(
2963 MLIRContext *context, mlir::acc::FirstprivateOp op,
2964 mlir::acc::FirstprivateRecipeOp recipe) {
2965 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2966 getFirstprivateOperandsMutable().append(op.getResult());
2969void acc::SerialOp::addReduction(
MLIRContext *context,
2970 mlir::acc::ReductionOp op,
2971 mlir::acc::ReductionRecipeOp recipe) {
2972 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2973 getReductionOperandsMutable().append(op.getResult());
2980unsigned KernelsOp::getNumDataOperands() {
2981 return getDataClauseOperands().size();
2984Value KernelsOp::getDataOperand(
unsigned i) {
2986 numOptional += getWaitOperands().size();
2987 numOptional += getNumGangs().size();
2988 numOptional += getNumWorkers().size();
2989 numOptional += getVectorLength().size();
2990 numOptional += getIfCond() ? 1 : 0;
2991 numOptional += getSelfCond() ? 1 : 0;
2992 return getOperand(numOptional + i);
2995bool acc::KernelsOp::hasAsyncOnly() {
2996 return hasAsyncOnly(mlir::acc::DeviceType::None);
2999bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3004 return getAsyncValue(mlir::acc::DeviceType::None);
3007mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3013 return getNumWorkersValue(mlir::acc::DeviceType::None);
3017acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3022mlir::Value acc::KernelsOp::getVectorLengthValue() {
3023 return getVectorLengthValue(mlir::acc::DeviceType::None);
3027acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3029 getVectorLength(), deviceType);
3033 return getNumGangsValues(mlir::acc::DeviceType::None);
3037KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3039 getNumGangsSegments(), deviceType);
3042bool acc::KernelsOp::hasWaitOnly() {
3043 return hasWaitOnly(mlir::acc::DeviceType::None);
3046bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3051 return getWaitValues(mlir::acc::DeviceType::None);
3055KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3057 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3058 getHasWaitDevnum(), deviceType);
3062 return getWaitDevnum(mlir::acc::DeviceType::None);
3065mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3067 getWaitOperandsSegments(), getHasWaitDevnum(),
3071LogicalResult acc::KernelsOp::verify() {
3073 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
3074 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
3078 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3079 getWaitOperandsDeviceTypeAttr(),
"wait")))
3083 getNumWorkersDeviceTypeAttr(),
3088 getVectorLengthDeviceTypeAttr(),
3093 getAsyncOperandsDeviceTypeAttr(),
3103void acc::KernelsOp::addPrivatization(
MLIRContext *context,
3104 mlir::acc::PrivateOp op,
3105 mlir::acc::PrivateRecipeOp recipe) {
3106 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3107 getPrivateOperandsMutable().append(op.getResult());
3110void acc::KernelsOp::addFirstPrivatization(
3111 MLIRContext *context, mlir::acc::FirstprivateOp op,
3112 mlir::acc::FirstprivateRecipeOp recipe) {
3113 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3114 getFirstprivateOperandsMutable().append(op.getResult());
3117void acc::KernelsOp::addReduction(
MLIRContext *context,
3118 mlir::acc::ReductionOp op,
3119 mlir::acc::ReductionRecipeOp recipe) {
3120 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3121 getReductionOperandsMutable().append(op.getResult());
3124void acc::KernelsOp::addNumWorkersOperand(
3127 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3128 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3129 getNumWorkersMutable()));
3132void acc::KernelsOp::addVectorLengthOperand(
3135 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3136 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3137 getVectorLengthMutable()));
3139void acc::KernelsOp::addAsyncOnly(
3141 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3142 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3145void acc::KernelsOp::addAsyncOperand(
3148 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3149 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3150 getAsyncOperandsMutable()));
3153void acc::KernelsOp::addNumGangsOperands(
3157 if (getNumGangsSegmentsAttr())
3158 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3160 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3161 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3162 getNumGangsMutable(), segments));
3164 setNumGangsSegments(segments);
3167void acc::KernelsOp::addWaitOnly(
3169 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3170 effectiveDeviceTypes));
3172void acc::KernelsOp::addWaitOperands(
3177 if (getWaitOperandsSegments())
3178 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3180 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3181 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3182 getWaitOperandsMutable(), segments));
3183 setWaitOperandsSegments(segments);
3186 if (getHasWaitDevnumAttr())
3187 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3190 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3192 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3199LogicalResult acc::HostDataOp::verify() {
3200 if (getDataClauseOperands().empty())
3201 return emitError(
"at least one operand must appear on the host_data "
3205 for (
mlir::Value operand : getDataClauseOperands()) {
3207 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3209 return emitError(
"expect data entry operation as defining op");
3212 if (!seenVars.insert(useDeviceOp.getVar()).second)
3213 return emitError(
"duplicate use_device variable");
3220 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3232 bool &needCommaBetweenValues,
bool &newValue) {
3239 attributes.push_back(gangArgType);
3240 needCommaBetweenValues =
true;
3251 mlir::ArrayAttr &gangOnlyDeviceType) {
3256 bool needCommaBetweenValues =
false;
3257 bool needCommaBeforeOperands =
false;
3261 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3262 parser.
getContext(), mlir::acc::DeviceType::None));
3263 gangOnlyDeviceType =
3264 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3272 if (parser.parseAttribute(
3273 gangOnlyDeviceTypeAttributes.emplace_back()))
3280 needCommaBeforeOperands =
true;
3283 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3284 mlir::acc::GangArgType::Num);
3285 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3286 mlir::acc::GangArgType::Dim);
3287 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3288 parser.
getContext(), mlir::acc::GangArgType::Static);
3291 if (needCommaBeforeOperands) {
3292 needCommaBeforeOperands =
false;
3299 int32_t crtOperandsSize = gangOperands.size();
3301 bool newValue =
false;
3302 bool needValue =
false;
3303 if (needCommaBetweenValues) {
3311 gangOperands, gangOperandsType,
3312 gangArgTypeAttributes, argNum,
3313 needCommaBetweenValues, newValue)))
3316 gangOperands, gangOperandsType,
3317 gangArgTypeAttributes, argDim,
3318 needCommaBetweenValues, newValue)))
3320 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3321 gangOperands, gangOperandsType,
3322 gangArgTypeAttributes, argStatic,
3323 needCommaBetweenValues, newValue)))
3326 if (!newValue && needValue) {
3328 "new value expected after comma");
3336 if (gangOperands.empty())
3339 "expect at least one of num, dim or static values");
3345 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3349 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3350 parser.
getContext(), mlir::acc::DeviceType::None));
3353 seg.push_back(gangOperands.size() - crtOperandsSize);
3361 gangArgTypeAttributes.end());
3362 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3363 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3366 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3367 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3375 std::optional<mlir::ArrayAttr> gangArgTypes,
3376 std::optional<mlir::ArrayAttr> deviceTypes,
3377 std::optional<mlir::DenseI32ArrayAttr> segments,
3378 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3380 if (operands.begin() == operands.end() &&
3395 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3397 llvm::interleaveComma(
3398 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3399 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3400 (*gangArgTypes)[opIdx]);
3401 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3402 p << LoopOp::getGangNumKeyword();
3403 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3404 p << LoopOp::getGangDimKeyword();
3405 else if (gangArgTypeAttr.getValue() ==
3406 mlir::acc::GangArgType::Static)
3407 p << LoopOp::getGangStaticKeyword();
3408 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3419 std::optional<mlir::ArrayAttr> segments,
3420 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3423 for (
auto attr : *segments) {
3424 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3425 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3433static std::optional<mlir::acc::DeviceType>
3435 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3437 return std::nullopt;
3438 for (
auto attr : deviceTypes) {
3439 auto deviceTypeAttr =
3440 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3441 if (!deviceTypeAttr)
3442 return mlir::acc::DeviceType::None;
3443 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3444 return deviceTypeAttr.getValue();
3446 return std::nullopt;
3449LogicalResult acc::LoopOp::verify() {
3450 if (getUpperbound().size() != getStep().size())
3451 return emitError() <<
"number of upperbounds expected to be the same as "
3454 if (getUpperbound().size() != getLowerbound().size())
3455 return emitError() <<
"number of upperbounds expected to be the same as "
3456 "number of lowerbounds";
3458 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3459 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3460 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3461 <<
" as upperbound size";
3464 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3465 return emitOpError() <<
"collapse device_type attr must be define when"
3466 <<
" collapse attr is present";
3468 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3469 getCollapseAttr().getValue().size() !=
3470 getCollapseDeviceTypeAttr().getValue().size())
3471 return emitOpError() <<
"collapse attribute count must match collapse"
3472 <<
" device_type count";
3473 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3475 << acc::stringifyDeviceType(*duplicateDeviceType)
3476 <<
"` found in collapseDeviceType attribute";
3479 if (!getGangOperands().empty()) {
3480 if (!getGangOperandsArgType())
3481 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3482 <<
" when gang operands are present";
3484 if (getGangOperands().size() !=
3485 getGangOperandsArgTypeAttr().getValue().size())
3486 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3487 <<
" gangOperands count";
3489 if (getGangAttr()) {
3492 << acc::stringifyDeviceType(*duplicateDeviceType)
3493 <<
"` found in gang attribute";
3497 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3498 getGangOperandsDeviceTypeAttr(),
"gang")))
3504 << acc::stringifyDeviceType(*duplicateDeviceType)
3505 <<
"` found in worker attribute";
3506 if (
auto duplicateDeviceType =
3509 << acc::stringifyDeviceType(*duplicateDeviceType)
3510 <<
"` found in workerNumOperandsDeviceType attribute";
3512 getWorkerNumOperandsDeviceTypeAttr(),
3519 << acc::stringifyDeviceType(*duplicateDeviceType)
3520 <<
"` found in vector attribute";
3521 if (
auto duplicateDeviceType =
3524 << acc::stringifyDeviceType(*duplicateDeviceType)
3525 <<
"` found in vectorOperandsDeviceType attribute";
3527 getVectorOperandsDeviceTypeAttr(),
3532 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3533 getTileOperandsDeviceTypeAttr(),
"tile")))
3537 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3541 return emitError() <<
"only one of auto, independent, seq can be present "
3547 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3548 return attr.getValue() == mlir::acc::DeviceType::None;
3550 bool hasDefaultSeq =
3552 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3555 bool hasDefaultIndependent =
3556 getIndependentAttr()
3558 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3561 bool hasDefaultAuto =
3563 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3566 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3568 <<
"at least one of auto, independent, seq must be present";
3573 for (
auto attr : getSeqAttr()) {
3574 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3575 if (hasVector(deviceTypeAttr.getValue()) ||
3576 getVectorValue(deviceTypeAttr.getValue()) ||
3577 hasWorker(deviceTypeAttr.getValue()) ||
3578 getWorkerValue(deviceTypeAttr.getValue()) ||
3579 hasGang(deviceTypeAttr.getValue()) ||
3580 getGangValue(mlir::acc::GangArgType::Num,
3581 deviceTypeAttr.getValue()) ||
3582 getGangValue(mlir::acc::GangArgType::Dim,
3583 deviceTypeAttr.getValue()) ||
3584 getGangValue(mlir::acc::GangArgType::Static,
3585 deviceTypeAttr.getValue()))
3586 return emitError() <<
"gang, worker or vector cannot appear with seq";
3591 mlir::acc::PrivateRecipeOp>(
3592 *
this, getPrivateOperands(),
"private")))
3596 mlir::acc::FirstprivateRecipeOp>(
3597 *
this, getFirstprivateOperands(),
"firstprivate")))
3601 mlir::acc::ReductionRecipeOp>(
3602 *
this, getReductionOperands(),
"reduction")))
3605 if (getCombined().has_value() &&
3606 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3607 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3608 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3609 return emitError(
"unexpected combined constructs attribute");
3613 if (getRegion().empty())
3614 return emitError(
"expected non-empty body.");
3616 if (getUnstructured()) {
3617 if (!isContainerLike())
3619 "unstructured acc.loop must not have induction variables");
3620 }
else if (isContainerLike()) {
3624 uint64_t collapseCount = getCollapseValue().value_or(1);
3625 if (getCollapseAttr()) {
3626 for (
auto collapseEntry : getCollapseAttr()) {
3627 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3628 if (intAttr.getValue().getZExtValue() > collapseCount)
3629 collapseCount = intAttr.getValue().getZExtValue();
3637 bool foundSibling =
false;
3639 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3641 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3643 foundSibling =
true;
3648 expectedParent = op;
3651 if (collapseCount == 0)
3657 return emitError(
"found sibling loops inside container-like acc.loop");
3658 if (collapseCount != 0)
3659 return emitError(
"failed to find enough loop-like operations inside "
3660 "container-like acc.loop");
3666unsigned LoopOp::getNumDataOperands() {
3667 return getReductionOperands().size() + getPrivateOperands().size() +
3668 getFirstprivateOperands().size();
3671Value LoopOp::getDataOperand(
unsigned i) {
3672 unsigned numOptional =
3673 getLowerbound().size() + getUpperbound().size() + getStep().size();
3674 numOptional += getGangOperands().size();
3675 numOptional += getVectorOperands().size();
3676 numOptional += getWorkerNumOperands().size();
3677 numOptional += getTileOperands().size();
3678 numOptional += getCacheOperands().size();
3679 return getOperand(numOptional + i);
3682bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3684bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3688bool LoopOp::hasIndependent() {
3689 return hasIndependent(mlir::acc::DeviceType::None);
3692bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3696bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3698bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3703 return getVectorValue(mlir::acc::DeviceType::None);
3706mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3708 getVectorOperands(), deviceType);
3711bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3713bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3718 return getWorkerValue(mlir::acc::DeviceType::None);
3721mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3723 getWorkerNumOperands(), deviceType);
3726bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3728bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3733 return getTileValues(mlir::acc::DeviceType::None);
3737LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3739 getTileOperandsSegments(), deviceType);
3742std::optional<int64_t> LoopOp::getCollapseValue() {
3743 return getCollapseValue(mlir::acc::DeviceType::None);
3746std::optional<int64_t>
3747LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3748 if (!getCollapseAttr())
3749 return std::nullopt;
3750 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3752 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3753 return intAttr.getValue().getZExtValue();
3755 return std::nullopt;
3758mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3759 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3762mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3763 mlir::acc::DeviceType deviceType) {
3764 if (getGangOperands().empty())
3766 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3767 int32_t nbOperandsBefore = 0;
3768 for (
unsigned i = 0; i < *pos; ++i)
3769 nbOperandsBefore += (*getGangOperandsSegments())[i];
3772 .drop_front(nbOperandsBefore)
3773 .take_front((*getGangOperandsSegments())[*pos]);
3775 int32_t argTypeIdx = nbOperandsBefore;
3776 for (
auto value : values) {
3777 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3778 (*getGangOperandsArgType())[argTypeIdx]);
3779 if (gangArgTypeAttr.getValue() == gangArgType)
3787bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3789bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3794 return {&getRegion()};
3838 if (!regionArgs.empty()) {
3839 p << acc::LoopOp::getControlKeyword() <<
"(";
3840 llvm::interleaveComma(regionArgs, p,
3842 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3843 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3844 <<
" : " << stepType <<
") ";
3851 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3852 effectiveDeviceTypes));
3855void acc::LoopOp::addIndependent(
3857 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3858 context, getIndependentAttr(), effectiveDeviceTypes));
3863 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3864 effectiveDeviceTypes));
3867void acc::LoopOp::setCollapseForDeviceTypes(
3869 llvm::APInt value) {
3873 assert((getCollapseAttr() ==
nullptr) ==
3874 (getCollapseDeviceTypeAttr() ==
nullptr));
3875 assert(value.getBitWidth() == 64);
3877 if (getCollapseAttr()) {
3878 for (
const auto &existing :
3879 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3880 newValues.push_back(std::get<0>(existing));
3881 newDeviceTypes.push_back(std::get<1>(existing));
3885 if (effectiveDeviceTypes.empty()) {
3888 newValues.push_back(
3889 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3890 newDeviceTypes.push_back(
3891 acc::DeviceTypeAttr::get(context, DeviceType::None));
3893 for (DeviceType dt : effectiveDeviceTypes) {
3894 newValues.push_back(
3895 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3896 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3900 setCollapseAttr(ArrayAttr::get(context, newValues));
3901 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3904void acc::LoopOp::setTileForDeviceTypes(
3908 if (getTileOperandsSegments())
3909 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3911 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3912 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3913 getTileOperandsMutable(), segments));
3915 setTileOperandsSegments(segments);
3918void acc::LoopOp::addVectorOperand(
3921 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3922 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3923 newValue, getVectorOperandsMutable()));
3926void acc::LoopOp::addEmptyVector(
3928 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3929 effectiveDeviceTypes));
3932void acc::LoopOp::addWorkerNumOperand(
3935 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3936 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3937 newValue, getWorkerNumOperandsMutable()));
3940void acc::LoopOp::addEmptyWorker(
3942 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3943 effectiveDeviceTypes));
3946void acc::LoopOp::addEmptyGang(
3948 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3949 effectiveDeviceTypes));
3952bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3953 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3954 return attr.getValue() == dt;
3956 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3957 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3960 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3962 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3964 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3970bool acc::LoopOp::hasDefaultGangWorkerVector() {
3971 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3972 hasGang() || getGangValue(GangArgType::Num) ||
3973 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3977acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3978 if (hasSeq(deviceType))
3979 return LoopParMode::loop_seq;
3980 if (hasAuto(deviceType))
3981 return LoopParMode::loop_auto;
3982 if (hasIndependent(deviceType))
3983 return LoopParMode::loop_independent;
3985 return LoopParMode::loop_seq;
3987 return LoopParMode::loop_auto;
3988 assert(hasIndependent() &&
3989 "loop must have default auto, seq, or independent");
3990 return LoopParMode::loop_independent;
3993void acc::LoopOp::addGangOperands(
3998 getGangOperandsSegments())
3999 llvm::copy(*existingSegments, std::back_inserter(segments));
4001 unsigned beforeCount = segments.size();
4003 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4004 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
4005 getGangOperandsMutable(), segments));
4007 setGangOperandsSegments(segments);
4014 unsigned numAdded = segments.size() - beforeCount;
4018 if (getGangOperandsArgTypeAttr())
4019 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
4021 for (
auto i : llvm::index_range(0u, numAdded)) {
4022 llvm::transform(argTypes, std::back_inserter(gangTypes),
4023 [=](mlir::acc::GangArgType gangTy) {
4024 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
4029 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
4033void acc::LoopOp::addPrivatization(
MLIRContext *context,
4034 mlir::acc::PrivateOp op,
4035 mlir::acc::PrivateRecipeOp recipe) {
4036 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4037 getPrivateOperandsMutable().append(op.getResult());
4040void acc::LoopOp::addFirstPrivatization(
4041 MLIRContext *context, mlir::acc::FirstprivateOp op,
4042 mlir::acc::FirstprivateRecipeOp recipe) {
4043 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4044 getFirstprivateOperandsMutable().append(op.getResult());
4047void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
4048 mlir::acc::ReductionRecipeOp recipe) {
4049 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4050 getReductionOperandsMutable().append(op.getResult());
4057LogicalResult acc::DataOp::verify() {
4062 return emitError(
"at least one operand or the default attribute "
4063 "must appear on the data operation");
4065 for (
mlir::Value operand : getDataClauseOperands())
4066 if (isa<BlockArgument>(operand) ||
4067 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4068 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
4069 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
4070 operand.getDefiningOp()))
4071 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4080unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
4082Value DataOp::getDataOperand(
unsigned i) {
4083 unsigned numOptional = getIfCond() ? 1 : 0;
4085 numOptional += getWaitOperands().size();
4086 return getOperand(numOptional + i);
4089bool acc::DataOp::hasAsyncOnly() {
4090 return hasAsyncOnly(mlir::acc::DeviceType::None);
4093bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4098 return getAsyncValue(mlir::acc::DeviceType::None);
4101mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4106bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
4108bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4113 return getWaitValues(mlir::acc::DeviceType::None);
4117DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4119 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4120 getHasWaitDevnum(), deviceType);
4124 return getWaitDevnum(mlir::acc::DeviceType::None);
4127mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4129 getWaitOperandsSegments(), getHasWaitDevnum(),
4133void acc::DataOp::addAsyncOnly(
4135 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4136 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4139void acc::DataOp::addAsyncOperand(
4142 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4143 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4144 getAsyncOperandsMutable()));
4147void acc::DataOp::addWaitOnly(
MLIRContext *context,
4149 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4150 effectiveDeviceTypes));
4153void acc::DataOp::addWaitOperands(
4158 if (getWaitOperandsSegments())
4159 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4161 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4162 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4163 getWaitOperandsMutable(), segments));
4164 setWaitOperandsSegments(segments);
4167 if (getHasWaitDevnumAttr())
4168 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4171 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4173 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4180LogicalResult acc::ExitDataOp::verify() {
4184 if (getDataClauseOperands().empty())
4185 return emitError(
"at least one operand must be present in dataOperands on "
4186 "the exit data operation");
4190 if (getAsyncOperand() && getAsync())
4191 return emitError(
"async attribute cannot appear with asyncOperand");
4195 if (!getWaitOperands().empty() && getWait())
4196 return emitError(
"wait attribute cannot appear with waitOperands");
4198 if (getWaitDevnum() && getWaitOperands().empty())
4199 return emitError(
"wait_devnum cannot appear without waitOperands");
4204unsigned ExitDataOp::getNumDataOperands() {
4205 return getDataClauseOperands().size();
4208Value ExitDataOp::getDataOperand(
unsigned i) {
4209 unsigned numOptional = getIfCond() ? 1 : 0;
4210 numOptional += getAsyncOperand() ? 1 : 0;
4211 numOptional += getWaitDevnum() ? 1 : 0;
4212 return getOperand(getWaitOperands().size() + numOptional + i);
4217 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
4220void ExitDataOp::addAsyncOnly(
MLIRContext *context,
4222 assert(effectiveDeviceTypes.empty());
4223 assert(!getAsyncAttr());
4224 assert(!getAsyncOperand());
4226 setAsyncAttr(mlir::UnitAttr::get(context));
4229void ExitDataOp::addAsyncOperand(
4232 assert(effectiveDeviceTypes.empty());
4233 assert(!getAsyncAttr());
4234 assert(!getAsyncOperand());
4236 getAsyncOperandMutable().append(newValue);
4241 assert(effectiveDeviceTypes.empty());
4242 assert(!getWaitAttr());
4243 assert(getWaitOperands().empty());
4244 assert(!getWaitDevnum());
4246 setWaitAttr(mlir::UnitAttr::get(context));
4249void ExitDataOp::addWaitOperands(
4252 assert(effectiveDeviceTypes.empty());
4253 assert(!getWaitAttr());
4254 assert(getWaitOperands().empty());
4255 assert(!getWaitDevnum());
4260 getWaitDevnumMutable().append(newValues.front());
4261 newValues = newValues.drop_front();
4264 getWaitOperandsMutable().append(newValues);
4271LogicalResult acc::EnterDataOp::verify() {
4275 if (getDataClauseOperands().empty())
4276 return emitError(
"at least one operand must be present in dataOperands on "
4277 "the enter data operation");
4281 if (getAsyncOperand() && getAsync())
4282 return emitError(
"async attribute cannot appear with asyncOperand");
4286 if (!getWaitOperands().empty() && getWait())
4287 return emitError(
"wait attribute cannot appear with waitOperands");
4289 if (getWaitDevnum() && getWaitOperands().empty())
4290 return emitError(
"wait_devnum cannot appear without waitOperands");
4292 for (
mlir::Value operand : getDataClauseOperands())
4293 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4294 operand.getDefiningOp()))
4295 return emitError(
"expect data entry operation as defining op");
4300unsigned EnterDataOp::getNumDataOperands() {
4301 return getDataClauseOperands().size();
4304Value EnterDataOp::getDataOperand(
unsigned i) {
4305 unsigned numOptional = getIfCond() ? 1 : 0;
4306 numOptional += getAsyncOperand() ? 1 : 0;
4307 numOptional += getWaitDevnum() ? 1 : 0;
4308 return getOperand(getWaitOperands().size() + numOptional + i);
4313 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4316void EnterDataOp::addAsyncOnly(
4318 assert(effectiveDeviceTypes.empty());
4319 assert(!getAsyncAttr());
4320 assert(!getAsyncOperand());
4322 setAsyncAttr(mlir::UnitAttr::get(context));
4325void EnterDataOp::addAsyncOperand(
4328 assert(effectiveDeviceTypes.empty());
4329 assert(!getAsyncAttr());
4330 assert(!getAsyncOperand());
4332 getAsyncOperandMutable().append(newValue);
4335void EnterDataOp::addWaitOnly(
MLIRContext *context,
4337 assert(effectiveDeviceTypes.empty());
4338 assert(!getWaitAttr());
4339 assert(getWaitOperands().empty());
4340 assert(!getWaitDevnum());
4342 setWaitAttr(mlir::UnitAttr::get(context));
4345void EnterDataOp::addWaitOperands(
4348 assert(effectiveDeviceTypes.empty());
4349 assert(!getWaitAttr());
4350 assert(getWaitOperands().empty());
4351 assert(!getWaitDevnum());
4356 getWaitDevnumMutable().append(newValues.front());
4357 newValues = newValues.drop_front();
4360 getWaitOperandsMutable().append(newValues);
4367LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4373LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4379LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4386 if (
Value writeVal = op.getWriteOpVal()) {
4395LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4397LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4403AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4404 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4406 return dyn_cast<AtomicReadOp>(getSecondOp());
4409AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4410 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4412 return dyn_cast<AtomicWriteOp>(getSecondOp());
4415AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4416 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4418 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4421LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4427template <
typename Op>
4430 bool requireAtLeastOneOperand =
true) {
4431 if (operands.empty() && requireAtLeastOneOperand)
4434 "at least one operand must appear on the declare operation");
4437 if (isa<BlockArgument>(operand) ||
4438 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4439 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4440 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4441 operand.getDefiningOp()))
4443 "expect valid declare data entry operation or acc.getdeviceptr "
4447 assert(var &&
"declare operands can only be data entry operations which "
4450 std::optional<mlir::acc::DataClause> dataClauseOptional{
4452 assert(dataClauseOptional.has_value() &&
4453 "declare operands can only be data entry operations which must have "
4455 (
void)dataClauseOptional;
4461LogicalResult acc::DeclareEnterOp::verify() {
4469LogicalResult acc::DeclareExitOp::verify() {
4480LogicalResult acc::DeclareOp::verify() {
4489 acc::DeviceType dtype) {
4490 unsigned parallelism = 0;
4491 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4492 parallelism += op.hasWorker(dtype) ? 1 : 0;
4493 parallelism += op.hasVector(dtype) ? 1 : 0;
4494 parallelism += op.hasSeq(dtype) ? 1 : 0;
4498LogicalResult acc::RoutineOp::verify() {
4499 unsigned baseParallelism =
4502 if (baseParallelism > 1)
4503 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4504 "be present at the same time";
4506 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4508 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4509 if (dtype == acc::DeviceType::None)
4513 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4514 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4515 "be present at the same time for device_type `"
4516 << acc::stringifyDeviceType(dtype) <<
"`";
4523 mlir::ArrayAttr &bindIdName,
4524 mlir::ArrayAttr &bindStrName,
4525 mlir::ArrayAttr &deviceIdTypes,
4526 mlir::ArrayAttr &deviceStrTypes) {
4533 mlir::Attribute newAttr;
4534 bool isSymbolRefAttr;
4535 auto parseResult = parser.parseAttribute(newAttr);
4536 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4537 bindIdNameAttrs.push_back(symbolRefAttr);
4538 isSymbolRefAttr = true;
4539 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4540 bindStrNameAttrs.push_back(stringAttr);
4541 isSymbolRefAttr =
false;
4546 if (isSymbolRefAttr) {
4547 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4548 parser.getContext(), mlir::acc::DeviceType::None));
4550 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4551 parser.getContext(), mlir::acc::DeviceType::None));
4554 if (isSymbolRefAttr) {
4555 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4556 parser.parseRSquare())
4559 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4560 parser.parseRSquare())
4568 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4569 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4570 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4571 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4577 std::optional<mlir::ArrayAttr> bindIdName,
4578 std::optional<mlir::ArrayAttr> bindStrName,
4579 std::optional<mlir::ArrayAttr> deviceIdTypes,
4580 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4587 allBindNames.append(bindIdName->begin(), bindIdName->end());
4588 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4593 allBindNames.append(bindStrName->begin(), bindStrName->end());
4594 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4598 if (!allBindNames.empty())
4599 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4600 [&](
const auto &pair) {
4601 p << std::get<0>(pair);
4607 mlir::ArrayAttr &gang,
4608 mlir::ArrayAttr &gangDim,
4609 mlir::ArrayAttr &gangDimDeviceTypes) {
4612 gangDimDeviceTypeAttrs;
4613 bool needCommaBeforeOperands =
false;
4617 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4618 parser.
getContext(), mlir::acc::DeviceType::None));
4619 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4626 if (parser.parseAttribute(gangAttrs.emplace_back()))
4633 needCommaBeforeOperands =
true;
4636 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4640 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4641 parser.parseColon() ||
4642 parser.parseAttribute(gangDimAttrs.emplace_back()))
4644 if (succeeded(parser.parseOptionalLSquare())) {
4645 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4646 parser.parseRSquare())
4649 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4650 parser.getContext(), mlir::acc::DeviceType::None));
4656 if (
failed(parser.parseRParen()))
4659 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4660 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4661 gangDimDeviceTypes =
4662 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4668 std::optional<mlir::ArrayAttr> gang,
4669 std::optional<mlir::ArrayAttr> gangDim,
4670 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4673 gang->size() == 1) {
4674 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4675 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4687 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4688 [&](
const auto &pair) {
4689 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4690 p << std::get<0>(pair);
4698 mlir::ArrayAttr &deviceTypes) {
4702 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4703 parser.
getContext(), mlir::acc::DeviceType::None));
4704 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4711 if (parser.parseAttribute(attributes.emplace_back()))
4719 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4725 std::optional<mlir::ArrayAttr> deviceTypes) {
4728 auto deviceTypeAttr =
4729 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4730 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4739 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4745bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4747bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4751bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4753bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4757bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4759bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4763std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4764RoutineOp::getBindNameValue() {
4765 return getBindNameValue(mlir::acc::DeviceType::None);
4768std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4769RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4771 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4772 auto attr = (*getBindIdName())[*pos];
4773 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4774 assert(symbolRefAttr &&
"expected SymbolRef");
4775 return symbolRefAttr;
4780 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4781 auto attr = (*getBindStrName())[*pos];
4782 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4783 assert(stringAttr &&
"expected String");
4788 return std::nullopt;
4791bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4793bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4797std::optional<int64_t> RoutineOp::getGangDimValue() {
4798 return getGangDimValue(mlir::acc::DeviceType::None);
4801std::optional<int64_t>
4802RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4804 return std::nullopt;
4805 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4806 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4807 return intAttr.getInt();
4809 return std::nullopt;
4814 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4815 effectiveDeviceTypes));
4820 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4821 effectiveDeviceTypes));
4826 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4827 effectiveDeviceTypes));
4832 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4833 effectiveDeviceTypes));
4842 if (getGangDimAttr())
4843 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4844 if (getGangDimDeviceTypeAttr())
4845 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4847 assert(dimValues.size() == deviceTypes.size());
4849 if (effectiveDeviceTypes.empty()) {
4850 dimValues.push_back(
4851 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4852 deviceTypes.push_back(
4853 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4855 for (DeviceType dt : effectiveDeviceTypes) {
4856 dimValues.push_back(
4857 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4858 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4861 assert(dimValues.size() == deviceTypes.size());
4863 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4864 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4867void RoutineOp::addBindStrName(
MLIRContext *context,
4869 mlir::StringAttr val) {
4870 unsigned before = getBindStrNameDeviceTypeAttr()
4871 ? getBindStrNameDeviceTypeAttr().size()
4874 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4875 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4876 unsigned after = getBindStrNameDeviceTypeAttr().size();
4879 if (getBindStrNameAttr())
4880 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4881 for (
unsigned i = 0; i < after - before; ++i)
4882 vals.push_back(val);
4884 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4887void RoutineOp::addBindIDName(
MLIRContext *context,
4889 mlir::SymbolRefAttr val) {
4891 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4893 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4894 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4895 unsigned after = getBindIdNameDeviceTypeAttr().size();
4898 if (getBindIdNameAttr())
4899 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4900 for (
unsigned i = 0; i < after - before; ++i)
4901 vals.push_back(val);
4903 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4910LogicalResult acc::InitOp::verify() {
4911 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4912 return emitOpError(
"cannot be nested in a compute operation");
4916void acc::InitOp::addDeviceType(
MLIRContext *context,
4917 mlir::acc::DeviceType deviceType) {
4919 if (getDeviceTypesAttr())
4920 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4922 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4923 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4930LogicalResult acc::ShutdownOp::verify() {
4931 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4932 return emitOpError(
"cannot be nested in a compute operation");
4936void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4937 mlir::acc::DeviceType deviceType) {
4939 if (getDeviceTypesAttr())
4940 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4942 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4943 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4950LogicalResult acc::SetOp::verify() {
4951 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4952 return emitOpError(
"cannot be nested in a compute operation");
4953 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4954 return emitOpError(
"at least one default_async, device_num, or device_type "
4955 "operand must appear");
4963LogicalResult acc::UpdateOp::verify() {
4965 if (getDataClauseOperands().empty())
4966 return emitError(
"at least one value must be present in dataOperands");
4969 getAsyncOperandsDeviceTypeAttr(),
4974 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4975 getWaitOperandsDeviceTypeAttr(),
"wait")))
4981 for (
mlir::Value operand : getDataClauseOperands())
4982 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4983 operand.getDefiningOp()))
4984 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4990unsigned UpdateOp::getNumDataOperands() {
4991 return getDataClauseOperands().size();
4994Value UpdateOp::getDataOperand(
unsigned i) {
4996 numOptional += getIfCond() ? 1 : 0;
4997 return getOperand(getWaitOperands().size() + numOptional + i);
5002 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
5005bool UpdateOp::hasAsyncOnly() {
5006 return hasAsyncOnly(mlir::acc::DeviceType::None);
5009bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5014 return getAsyncValue(mlir::acc::DeviceType::None);
5017mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
5027bool UpdateOp::hasWaitOnly() {
5028 return hasWaitOnly(mlir::acc::DeviceType::None);
5031bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5036 return getWaitValues(mlir::acc::DeviceType::None);
5040UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
5042 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
5043 getHasWaitDevnum(), deviceType);
5047 return getWaitDevnum(mlir::acc::DeviceType::None);
5050mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5052 getWaitOperandsSegments(), getHasWaitDevnum(),
5058 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5059 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5062void UpdateOp::addAsyncOperand(
5065 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5066 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5067 getAsyncOperandsMutable()));
5072 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5073 effectiveDeviceTypes));
5076void UpdateOp::addWaitOperands(
5081 if (getWaitOperandsSegments())
5082 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5084 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5085 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5086 getWaitOperandsMutable(), segments));
5087 setWaitOperandsSegments(segments);
5090 if (getHasWaitDevnumAttr())
5091 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5094 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
5096 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5103LogicalResult acc::WaitOp::verify() {
5106 if (getAsyncOperand() && getAsync())
5107 return emitError(
"async attribute cannot appear with asyncOperand");
5109 if (getWaitDevnum() && getWaitOperands().empty())
5110 return emitError(
"wait_devnum cannot appear without waitOperands");
5115#define GET_OP_CLASSES
5116#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5118#define GET_ATTRDEF_CLASSES
5119#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5121#define GET_TYPEDEF_CLASSES
5122#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5133 .Case<ACC_DATA_ENTRY_OPS>(
5134 [&](
auto entry) {
return entry.getVarPtr(); })
5135 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5136 [&](
auto exit) {
return exit.getVarPtr(); })
5154 [&](
auto entry) {
return entry.getVarType(); })
5155 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5156 [&](
auto exit) {
return exit.getVarType(); })
5166 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5167 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
5177 [&](
auto dataClause) {
return dataClause.getAccVar(); })
5186 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
5196 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5198 dataClause.getBounds().begin(), dataClause.getBounds().end());
5210 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
5212 dataClause.getAsyncOperands().begin(),
5213 dataClause.getAsyncOperands().end());
5224 return dataClause.getAsyncOperandsDeviceTypeAttr();
5232 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
5239 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
5246std::optional<mlir::acc::DataClause>
5251 .Case<ACC_DATA_ENTRY_OPS>(
5252 [&](
auto entry) {
return entry.getDataClause(); })
5260 [&](
auto entry) {
return entry.getImplicit(); })
5269 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5271 return dataOperands;
5279 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5281 return dataOperands;
5288 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region ®ion, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > ®ions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
unsigned size() const
Returns the current size of the range.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
OperandRange operand_range
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static DerivedEffect * get()
static CurrentDeviceIdResource * get()
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_CONSTRUCT_OPS
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.