25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
41static bool isScalarLikeType(
Type type) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
56struct MemRefPointerLikeModel
57 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 return cast<T>(pointer).getElementType();
62 mlir::acc::VariableTypeCategory
65 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
72 return mlir::acc::VariableTypeCategory::uncategorized;
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
81 return mlir::acc::VariableTypeCategory::uncategorized;
85 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree)
const {
92 auto memrefTy = cast<MemRefType>(pointer);
96 if (memrefTy.hasStaticShape()) {
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
105 if (originalVar && originalVar.
getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 Type varType)
const {
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc =
nullptr;
143 while (currentValue) {
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
152 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
158 if (
auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
172 if (isa<memref::AllocaOp>(originalAlloc)) {
176 if (isa<memref::AllocOp>(originalAlloc)) {
178 memref::DeallocOp::create(builder, loc, memrefValue);
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
209 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType)
const {
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
220 auto memrefTy = memrefValue.
getType();
223 if (memrefTy.getRank() != 0)
226 return memref::LoadOp::create(builder, loc, memrefValue);
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
239 auto memrefTy = memrefValue.getType();
242 if (memrefTy.getRank() != 0)
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
249 bool isDeviceData(Type pointer, Value var)
const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
256struct LLVMPointerPointerLikeModel
257 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
261 mlir::Value
genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType)
const {
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
278struct MemrefAddressOfGlobalModel
279 :
public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op)
const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
287struct MemrefGlobalVariableModel
288 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
290 bool isConstant(Operation *op)
const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
295 Region *getInitRegion(Operation *op)
const {
300 bool isDeviceData(Operation *op)
const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
307struct GPULaunchOffloadRegionModel
308 :
public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op)
const {
311 return cast<gpu::LaunchOp>(op).getBody();
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
333 return mlir::ArrayAttr::get(context, deviceTypes);
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
351 if (newDeviceTypes.empty()) {
352 argCollection.
append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.
append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
364 return mlir::ArrayAttr::get(context, deviceTypes);
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
383void OpenACCDialect::initialize() {
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*
getContext());
448void ParallelOp::getSuccessorRegions(
468void KernelEnvironmentOp::getSuccessorRegions(
488void HostDataOp::getSuccessorRegions(
503 if (getUnstructured()) {
536 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
540 mlir::acc::DeviceType deviceType) {
544 for (
auto attr : *arrayAttr) {
545 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
546 if (deviceTypeAttr.getValue() == deviceType)
554 std::optional<mlir::ArrayAttr> deviceTypes) {
559 llvm::interleaveComma(*deviceTypes, p,
565 mlir::acc::DeviceType deviceType) {
566 unsigned segmentIdx = 0;
567 for (
auto attr : segments) {
568 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
569 if (deviceTypeAttr.getValue() == deviceType)
570 return std::make_optional(segmentIdx);
580 mlir::acc::DeviceType deviceType) {
582 return range.take_front(0);
583 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
584 int32_t nbOperandsBefore = 0;
585 for (
unsigned i = 0; i < *pos; ++i)
586 nbOperandsBefore += (*segments)[i];
587 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
589 return range.take_front(0);
596 std::optional<mlir::ArrayAttr> hasWaitDevnum,
597 mlir::acc::DeviceType deviceType) {
600 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
601 if (hasWaitDevnum->getValue()[*pos])
612 std::optional<mlir::ArrayAttr> hasWaitDevnum,
613 mlir::acc::DeviceType deviceType) {
618 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
619 if (hasWaitDevnum && *hasWaitDevnum) {
620 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
621 if (boolAttr.getValue())
622 return range.drop_front(1);
628template <
typename Op>
630 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
632 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
637 op.hasAsyncOnly(dtype))
639 "asyncOnly attribute cannot appear with asyncOperand");
644 op.hasWaitOnly(dtype))
645 return op.
emitError(
"wait attribute cannot appear with waitOperands");
650template <
typename Op>
653 return op.
emitError(
"must have var operand");
656 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
657 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
658 return op.
emitError(
"var must be mappable or pointer-like");
661 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
662 op.getVarType() == op.getVar().getType())
663 return op.
emitError(
"varType must capture the element type of var");
668template <
typename Op>
670 if (op.getVar().getType() != op.getAccVar().getType())
671 return op.
emitError(
"input and output types must match");
676template <
typename Op>
678 if (op.getModifiers() != acc::DataClauseModifier::none)
679 return op.
emitError(
"no data clause modifiers are allowed");
683template <
typename Op>
686 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
688 "invalid data clause modifiers: " +
689 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
694template <
typename OpT,
typename RecipeOpT>
695static LogicalResult
checkRecipe(OpT op, llvm::StringRef operandName) {
700 !std::is_same_v<OpT, acc::ReductionOp>)
703 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
705 return op->emitOpError() <<
"recipe expected for " << operandName;
710 return op->emitOpError()
711 <<
"expected symbol reference " << operandRecipe <<
" to point to a "
712 << operandName <<
" declaration";
733 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
754 if (failed(parser.
parseType(accVarType)))
764 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
776 mlir::TypeAttr &varTypeAttr) {
777 if (failed(parser.
parseType(varPtrType)))
788 varTypeAttr = mlir::TypeAttr::get(varType);
793 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
794 varTypeAttr = mlir::TypeAttr::get(
795 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
797 varTypeAttr = mlir::TypeAttr::get(varPtrType);
804 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
812 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
813 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
815 if (typeToCheckAgainst != varType) {
823 mlir::SymbolRefAttr &recipeAttr) {
830 mlir::SymbolRefAttr recipeAttr) {
837LogicalResult acc::DataBoundsOp::verify() {
838 auto extent = getExtent();
839 auto upperbound = getUpperbound();
840 if (!extent && !upperbound)
841 return emitError(
"expected extent or upperbound.");
848LogicalResult acc::PrivateOp::verify() {
851 "data clause associated with private operation must match its intent");
865LogicalResult acc::FirstprivateOp::verify() {
867 return emitError(
"data clause associated with firstprivate operation must "
874 *
this,
"firstprivate")))
882LogicalResult acc::FirstprivateMapInitialOp::verify() {
884 return emitError(
"data clause associated with firstprivate operation must "
896LogicalResult acc::ReductionOp::verify() {
898 return emitError(
"data clause associated with reduction operation must "
905 *
this,
"reduction")))
913LogicalResult acc::DevicePtrOp::verify() {
915 return emitError(
"data clause associated with deviceptr operation must "
929LogicalResult acc::PresentOp::verify() {
932 "data clause associated with present operation must match its intent");
945LogicalResult acc::CopyinOp::verify() {
947 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
952 "data clause associated with copyin operation must match its intent"
953 " or specify original clause this operation was decomposed from");
959 acc::DataClauseModifier::always |
960 acc::DataClauseModifier::capture)))
965bool acc::CopyinOp::isCopyinReadonly() {
966 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
967 acc::bitEnumContainsAny(getModifiers(),
968 acc::DataClauseModifier::readonly);
974LogicalResult acc::CreateOp::verify() {
981 "data clause associated with create operation must match its intent"
982 " or specify original clause this operation was decomposed from");
990 acc::DataClauseModifier::always |
991 acc::DataClauseModifier::capture)))
996bool acc::CreateOp::isCreateZero() {
998 return getDataClause() == acc::DataClause::acc_create_zero ||
1000 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1006LogicalResult acc::NoCreateOp::verify() {
1008 return emitError(
"data clause associated with no_create operation must "
1009 "match its intent");
1022LogicalResult acc::AttachOp::verify() {
1025 "data clause associated with attach operation must match its intent");
1039LogicalResult acc::DeclareDeviceResidentOp::verify() {
1040 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
1041 return emitError(
"data clause associated with device_resident operation "
1042 "must match its intent");
1056LogicalResult acc::DeclareLinkOp::verify() {
1059 "data clause associated with link operation must match its intent");
1072LogicalResult acc::CopyoutOp::verify() {
1079 "data clause associated with copyout operation must match its intent"
1080 " or specify original clause this operation was decomposed from");
1082 return emitError(
"must have both host and device pointers");
1088 acc::DataClauseModifier::always |
1089 acc::DataClauseModifier::capture)))
1094bool acc::CopyoutOp::isCopyoutZero() {
1095 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1096 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1102LogicalResult acc::DeleteOp::verify() {
1111 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1114 "data clause associated with delete operation must match its intent"
1115 " or specify original clause this operation was decomposed from");
1117 return emitError(
"must have device pointer");
1121 acc::DataClauseModifier::readonly |
1122 acc::DataClauseModifier::always |
1123 acc::DataClauseModifier::capture)))
1131LogicalResult acc::DetachOp::verify() {
1136 "data clause associated with detach operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1139 return emitError(
"must have device pointer");
1148LogicalResult acc::UpdateHostOp::verify() {
1153 "data clause associated with host operation must match its intent"
1154 " or specify original clause this operation was decomposed from");
1156 return emitError(
"must have both host and device pointers");
1169LogicalResult acc::UpdateDeviceOp::verify() {
1173 "data clause associated with device operation must match its intent"
1174 " or specify original clause this operation was decomposed from");
1187LogicalResult acc::UseDeviceOp::verify() {
1191 "data clause associated with use_device operation must match its intent"
1192 " or specify original clause this operation was decomposed from");
1205LogicalResult acc::CacheOp::verify() {
1210 "data clause associated with cache operation must match its intent"
1211 " or specify original clause this operation was decomposed from");
1221bool acc::CacheOp::isCacheReadonly() {
1222 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1223 acc::bitEnumContainsAny(getModifiers(),
1224 acc::DataClauseModifier::readonly);
1227template <
typename StructureOp>
1229 unsigned nRegions = 1) {
1232 for (
unsigned i = 0; i < nRegions; ++i)
1235 for (
Region *region : regions)
1243 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1250template <
typename OpTy>
1252 using OpRewritePattern<OpTy>::OpRewritePattern;
1254 LogicalResult matchAndRewrite(OpTy op,
1255 PatternRewriter &rewriter)
const override {
1257 Value ifCond = op.getIfCond();
1261 IntegerAttr constAttr;
1264 if (constAttr.getInt())
1265 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1277 assert(region.
hasOneBlock() &&
"expected single-block region");
1289template <
typename OpTy>
1290struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1291 using OpRewritePattern<OpTy>::OpRewritePattern;
1293 LogicalResult matchAndRewrite(OpTy op,
1294 PatternRewriter &rewriter)
const override {
1296 Value ifCond = op.getIfCond();
1300 IntegerAttr constAttr;
1303 if (constAttr.getInt())
1304 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1314struct RemoveEmptyKernelEnvironment
1316 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1318 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1319 PatternRewriter &rewriter)
const override {
1320 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1331 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1332 for (
auto attr : deviceTypeAttr) {
1333 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1334 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1341 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1342 for (
auto attr : hasDevnumAttr) {
1343 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1344 if (boolAttr.getValue())
1351 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1352 if (segmentsAttr.size() > 1)
1358 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1385 for (
Value bound : bounds) {
1386 argTypes.push_back(bound.getType());
1387 argLocs.push_back(loc);
1394 Value privatizedValue;
1400 if (isa<MappableType>(varType)) {
1401 auto mappableTy = cast<MappableType>(varType);
1402 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1403 privatizedValue = mappableTy.generatePrivateInit(
1404 builder, loc, typedVar, varName, bounds, {}, needsFree);
1405 if (!privatizedValue)
1408 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1409 auto pointerLikeTy = cast<PointerLikeType>(varType);
1411 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1412 blockArgVar, needsFree);
1413 if (!privatizedValue)
1418 acc::YieldOp::create(builder, loc, privatizedValue);
1433 for (
Value bound : bounds) {
1434 copyArgTypes.push_back(bound.getType());
1435 copyArgLocs.push_back(loc);
1442 bool isMappable = isa<MappableType>(varType);
1443 bool isPointerLike = isa<PointerLikeType>(varType);
1446 if (isMappable && !isPointerLike)
1450 if (isPointerLike) {
1451 auto pointerLikeTy = cast<PointerLikeType>(varType);
1456 if (!pointerLikeTy.genCopy(
1463 acc::TerminatorOp::create(builder, loc);
1477 for (
Value bound : bounds) {
1478 destroyArgTypes.push_back(bound.getType());
1479 destroyArgLocs.push_back(loc);
1483 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1487 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1488 if (isa<MappableType>(varType)) {
1489 auto mappableTy = cast<MappableType>(varType);
1490 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1493 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1494 auto pointerLikeTy = cast<PointerLikeType>(varType);
1495 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1499 acc::TerminatorOp::create(builder, loc);
1510 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1512 if (optional && region.
empty())
1516 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1520 return op->
emitOpError() <<
"expects " << regionName
1523 << regionType <<
" type";
1526 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1527 if (yieldOp.getOperands().size() != 1 ||
1528 yieldOp.getOperands().getTypes()[0] != type)
1529 return op->
emitOpError() <<
"expects " << regionName
1531 "yield a value of the "
1532 << regionType <<
" type";
1538LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1540 "privatization",
"init",
getType(),
1544 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1550std::optional<PrivateRecipeOp>
1552 StringRef recipeName,
Type varType,
1555 bool isMappable = isa<MappableType>(varType);
1556 bool isPointerLike = isa<PointerLikeType>(varType);
1559 if (!isMappable && !isPointerLike)
1560 return std::nullopt;
1565 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1568 bool needsFree =
false;
1569 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1570 varName, bounds, needsFree))) {
1572 return std::nullopt;
1579 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1580 Value allocRes = yieldOp.getOperand(0);
1582 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1583 varType, allocRes, bounds))) {
1585 return std::nullopt;
1592std::optional<PrivateRecipeOp>
1594 StringRef recipeName,
1595 FirstprivateRecipeOp firstprivRecipe) {
1598 auto varType = firstprivRecipe.getType();
1599 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1603 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1606 if (!firstprivRecipe.getDestroyRegion().empty()) {
1608 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1618LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1620 "privatization",
"init",
getType(),
1624 if (getCopyRegion().empty())
1625 return emitOpError() <<
"expects non-empty copy region";
1630 return emitOpError() <<
"expects copy region with two arguments of the "
1631 "privatization type";
1633 if (getDestroyRegion().empty())
1637 "privatization",
"destroy",
1644std::optional<FirstprivateRecipeOp>
1646 StringRef recipeName,
Type varType,
1649 bool isMappable = isa<MappableType>(varType);
1650 bool isPointerLike = isa<PointerLikeType>(varType);
1653 if (!isMappable && !isPointerLike)
1654 return std::nullopt;
1659 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1662 bool needsFree =
false;
1663 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1664 varName, bounds, needsFree))) {
1666 return std::nullopt;
1670 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1673 return std::nullopt;
1680 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1681 Value allocRes = yieldOp.getOperand(0);
1683 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1684 varType, allocRes, bounds))) {
1686 return std::nullopt;
1697LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1703 if (getCombinerRegion().empty())
1704 return emitOpError() <<
"expects non-empty combiner region";
1706 Block &reductionBlock = getCombinerRegion().
front();
1710 return emitOpError() <<
"expects combiner region with the first two "
1711 <<
"arguments of the reduction type";
1713 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1714 if (yieldOp.getOperands().size() != 1 ||
1715 yieldOp.getOperands().getTypes()[0] !=
getType())
1716 return emitOpError() <<
"expects combiner region to yield a value "
1717 "of the reduction type";
1728template <
typename Op>
1732 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1733 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1734 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1735 operand.getDefiningOp()))
1737 "expect data entry/exit operation or acc.getdeviceptr "
1742template <
typename OpT,
typename RecipeOpT>
1745 llvm::StringRef operandName) {
1748 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1750 <<
"expected " << operandName <<
" as defining op";
1751 if (!set.insert(operand).second)
1753 << operandName <<
" operand appears more than once";
1758unsigned ParallelOp::getNumDataOperands() {
1759 return getReductionOperands().size() + getPrivateOperands().size() +
1760 getFirstprivateOperands().size() + getDataClauseOperands().size();
1763Value ParallelOp::getDataOperand(
unsigned i) {
1765 numOptional += getNumGangs().size();
1766 numOptional += getNumWorkers().size();
1767 numOptional += getVectorLength().size();
1768 numOptional += getIfCond() ? 1 : 0;
1769 numOptional += getSelfCond() ? 1 : 0;
1770 return getOperand(getWaitOperands().size() + numOptional + i);
1773template <
typename Op>
1776 llvm::StringRef keyword) {
1777 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1778 return op.
emitOpError() << keyword <<
" operands count must match "
1779 << keyword <<
" device_type count";
1783template <
typename Op>
1786 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1787 std::size_t numOperandsInSegments = 0;
1788 std::size_t nbOfSegments = 0;
1791 for (
auto segCount : segments.
asArrayRef()) {
1792 if (maxInSegment != 0 && segCount > maxInSegment)
1793 return op.
emitOpError() << keyword <<
" expects a maximum of "
1794 << maxInSegment <<
" values per segment";
1795 numOperandsInSegments += segCount;
1800 if ((numOperandsInSegments != operands.size()) ||
1801 (!deviceTypes && !operands.empty()))
1803 << keyword <<
" operand count does not match count in segments";
1804 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1806 << keyword <<
" segment count does not match device_type count";
1810LogicalResult acc::ParallelOp::verify() {
1812 mlir::acc::PrivateRecipeOp>(
1813 *
this, getPrivateOperands(),
"private")))
1816 mlir::acc::FirstprivateRecipeOp>(
1817 *
this, getFirstprivateOperands(),
"firstprivate")))
1820 mlir::acc::ReductionRecipeOp>(
1821 *
this, getReductionOperands(),
"reduction")))
1825 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1826 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1830 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1831 getWaitOperandsDeviceTypeAttr(),
"wait")))
1835 getNumWorkersDeviceTypeAttr(),
1840 getVectorLengthDeviceTypeAttr(),
1845 getAsyncOperandsDeviceTypeAttr(),
1858 mlir::acc::DeviceType deviceType) {
1861 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1866bool acc::ParallelOp::hasAsyncOnly() {
1867 return hasAsyncOnly(mlir::acc::DeviceType::None);
1870bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1875 return getAsyncValue(mlir::acc::DeviceType::None);
1878mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1883mlir::Value acc::ParallelOp::getNumWorkersValue() {
1884 return getNumWorkersValue(mlir::acc::DeviceType::None);
1888acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1893mlir::Value acc::ParallelOp::getVectorLengthValue() {
1894 return getVectorLengthValue(mlir::acc::DeviceType::None);
1898acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1900 getVectorLength(), deviceType);
1904 return getNumGangsValues(mlir::acc::DeviceType::None);
1908ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1910 getNumGangsSegments(), deviceType);
1913bool acc::ParallelOp::hasWaitOnly() {
1914 return hasWaitOnly(mlir::acc::DeviceType::None);
1917bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1922 return getWaitValues(mlir::acc::DeviceType::None);
1926ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1928 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1929 getHasWaitDevnum(), deviceType);
1933 return getWaitDevnum(mlir::acc::DeviceType::None);
1936mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1938 getWaitOperandsSegments(), getHasWaitDevnum(),
1953 odsBuilder, odsState, asyncOperands,
nullptr,
1954 nullptr, waitOperands,
nullptr,
1956 nullptr, numGangs,
nullptr,
1957 nullptr, numWorkers,
1958 nullptr, vectorLength,
1959 nullptr, ifCond, selfCond,
1960 nullptr, reductionOperands, gangPrivateOperands,
1961 gangFirstPrivateOperands, dataClauseOperands,
1965void acc::ParallelOp::addNumWorkersOperand(
1968 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1969 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1970 getNumWorkersMutable()));
1972void acc::ParallelOp::addVectorLengthOperand(
1975 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1976 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1977 getVectorLengthMutable()));
1980void acc::ParallelOp::addAsyncOnly(
1982 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1983 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1986void acc::ParallelOp::addAsyncOperand(
1989 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1990 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1991 getAsyncOperandsMutable()));
1994void acc::ParallelOp::addNumGangsOperands(
1998 if (getNumGangsSegments())
1999 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2001 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2002 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2003 getNumGangsMutable(), segments));
2005 setNumGangsSegments(segments);
2007void acc::ParallelOp::addWaitOnly(
2009 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2010 effectiveDeviceTypes));
2012void acc::ParallelOp::addWaitOperands(
2017 if (getWaitOperandsSegments())
2018 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2020 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2021 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2022 getWaitOperandsMutable(), segments));
2023 setWaitOperandsSegments(segments);
2026 if (getHasWaitDevnumAttr())
2027 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2030 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2032 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2035void acc::ParallelOp::addPrivatization(
MLIRContext *context,
2036 mlir::acc::PrivateOp op,
2037 mlir::acc::PrivateRecipeOp recipe) {
2038 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2039 getPrivateOperandsMutable().append(op.getResult());
2042void acc::ParallelOp::addFirstPrivatization(
2043 MLIRContext *context, mlir::acc::FirstprivateOp op,
2044 mlir::acc::FirstprivateRecipeOp recipe) {
2045 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2046 getFirstprivateOperandsMutable().append(op.getResult());
2049void acc::ParallelOp::addReduction(
MLIRContext *context,
2050 mlir::acc::ReductionOp op,
2051 mlir::acc::ReductionRecipeOp recipe) {
2052 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2053 getReductionOperandsMutable().append(op.getResult());
2068 int32_t crtOperandsSize = operands.size();
2071 if (parser.parseOperand(operands.emplace_back()) ||
2072 parser.parseColonType(types.emplace_back()))
2077 seg.push_back(operands.size() - crtOperandsSize);
2087 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2088 parser.
getContext(), mlir::acc::DeviceType::None));
2094 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2101 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2102 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2103 p <<
" [" << attr <<
"]";
2108 std::optional<mlir::ArrayAttr> deviceTypes,
2109 std::optional<mlir::DenseI32ArrayAttr> segments) {
2111 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2113 llvm::interleaveComma(
2114 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2115 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2135 int32_t crtOperandsSize = operands.size();
2139 if (parser.parseOperand(operands.emplace_back()) ||
2140 parser.parseColonType(types.emplace_back()))
2146 seg.push_back(operands.size() - crtOperandsSize);
2156 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2157 parser.
getContext(), mlir::acc::DeviceType::None));
2163 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
2172 std::optional<mlir::DenseI32ArrayAttr> segments) {
2174 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2176 llvm::interleaveComma(
2177 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2178 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2191 mlir::ArrayAttr &keywordOnly) {
2195 bool needCommaBeforeOperands =
false;
2199 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2200 parser.
getContext(), mlir::acc::DeviceType::None));
2201 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2208 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2215 needCommaBeforeOperands =
true;
2218 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2225 int32_t crtOperandsSize = operands.size();
2237 if (parser.parseOperand(operands.emplace_back()) ||
2238 parser.parseColonType(types.emplace_back()))
2244 seg.push_back(operands.size() - crtOperandsSize);
2254 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2255 parser.
getContext(), mlir::acc::DeviceType::None));
2262 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2263 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2265 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2273 if (attrs->size() != 1)
2275 if (
auto deviceTypeAttr =
2276 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2277 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2283 std::optional<mlir::ArrayAttr> deviceTypes,
2284 std::optional<mlir::DenseI32ArrayAttr> segments,
2285 std::optional<mlir::ArrayAttr> hasDevNum,
2286 std::optional<mlir::ArrayAttr> keywordOnly) {
2299 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2301 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2302 if (boolAttr && boolAttr.getValue())
2304 llvm::interleaveComma(
2305 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2306 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2323 if (parser.parseOperand(operands.emplace_back()) ||
2324 parser.parseColonType(types.emplace_back()))
2326 if (succeeded(parser.parseOptionalLSquare())) {
2327 if (parser.parseAttribute(attributes.emplace_back()) ||
2328 parser.parseRSquare())
2331 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2332 parser.getContext(), mlir::acc::DeviceType::None));
2339 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2346 std::optional<mlir::ArrayAttr> deviceTypes) {
2349 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2350 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2359 mlir::ArrayAttr &keywordOnlyDeviceType) {
2362 bool needCommaBeforeOperands =
false;
2366 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2367 parser.
getContext(), mlir::acc::DeviceType::None));
2368 keywordOnlyDeviceType =
2369 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2377 if (parser.parseAttribute(
2378 keywordOnlyDeviceTypeAttributes.emplace_back()))
2385 needCommaBeforeOperands =
true;
2388 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2393 if (parser.parseOperand(operands.emplace_back()) ||
2394 parser.parseColonType(types.emplace_back()))
2396 if (succeeded(parser.parseOptionalLSquare())) {
2397 if (parser.parseAttribute(attributes.emplace_back()) ||
2398 parser.parseRSquare())
2401 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2402 parser.getContext(), mlir::acc::DeviceType::None));
2408 if (
failed(parser.parseRParen()))
2413 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2420 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2422 if (operands.begin() == operands.end() &&
2438 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2439 mlir::Type &operandType, mlir::UnitAttr &attr) {
2442 attr = mlir::UnitAttr::get(parser.
getContext());
2452 if (failed(parser.
parseType(operandType)))
2462 std::optional<mlir::Value> operand,
2464 mlir::UnitAttr attr) {
2481 attr = mlir::UnitAttr::get(parser.
getContext());
2486 if (parser.parseOperand(operands.emplace_back()))
2494 if (parser.parseType(types.emplace_back()))
2509 mlir::UnitAttr attr) {
2514 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2516 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2522 mlir::acc::CombinedConstructsTypeAttr &attr) {
2524 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2525 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2527 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2528 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2530 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2531 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2534 "expected compute construct name");
2542 mlir::acc::CombinedConstructsTypeAttr attr) {
2544 switch (attr.getValue()) {
2545 case mlir::acc::CombinedConstructsType::KernelsLoop:
2548 case mlir::acc::CombinedConstructsType::ParallelLoop:
2551 case mlir::acc::CombinedConstructsType::SerialLoop:
2562unsigned SerialOp::getNumDataOperands() {
2563 return getReductionOperands().size() + getPrivateOperands().size() +
2564 getFirstprivateOperands().size() + getDataClauseOperands().size();
2567Value SerialOp::getDataOperand(
unsigned i) {
2569 numOptional += getIfCond() ? 1 : 0;
2570 numOptional += getSelfCond() ? 1 : 0;
2571 return getOperand(getWaitOperands().size() + numOptional + i);
2574bool acc::SerialOp::hasAsyncOnly() {
2575 return hasAsyncOnly(mlir::acc::DeviceType::None);
2578bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2583 return getAsyncValue(mlir::acc::DeviceType::None);
2586mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2591bool acc::SerialOp::hasWaitOnly() {
2592 return hasWaitOnly(mlir::acc::DeviceType::None);
2595bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2600 return getWaitValues(mlir::acc::DeviceType::None);
2604SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2606 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2607 getHasWaitDevnum(), deviceType);
2611 return getWaitDevnum(mlir::acc::DeviceType::None);
2614mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2616 getWaitOperandsSegments(), getHasWaitDevnum(),
2620LogicalResult acc::SerialOp::verify() {
2622 mlir::acc::PrivateRecipeOp>(
2623 *
this, getPrivateOperands(),
"private")))
2626 mlir::acc::FirstprivateRecipeOp>(
2627 *
this, getFirstprivateOperands(),
"firstprivate")))
2630 mlir::acc::ReductionRecipeOp>(
2631 *
this, getReductionOperands(),
"reduction")))
2635 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2636 getWaitOperandsDeviceTypeAttr(),
"wait")))
2640 getAsyncOperandsDeviceTypeAttr(),
2650void acc::SerialOp::addAsyncOnly(
2652 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2653 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2656void acc::SerialOp::addAsyncOperand(
2659 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2660 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2661 getAsyncOperandsMutable()));
2664void acc::SerialOp::addWaitOnly(
2666 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2667 effectiveDeviceTypes));
2669void acc::SerialOp::addWaitOperands(
2674 if (getWaitOperandsSegments())
2675 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2677 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2678 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2679 getWaitOperandsMutable(), segments));
2680 setWaitOperandsSegments(segments);
2683 if (getHasWaitDevnumAttr())
2684 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2687 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2689 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2692void acc::SerialOp::addPrivatization(
MLIRContext *context,
2693 mlir::acc::PrivateOp op,
2694 mlir::acc::PrivateRecipeOp recipe) {
2695 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2696 getPrivateOperandsMutable().append(op.getResult());
2699void acc::SerialOp::addFirstPrivatization(
2700 MLIRContext *context, mlir::acc::FirstprivateOp op,
2701 mlir::acc::FirstprivateRecipeOp recipe) {
2702 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2703 getFirstprivateOperandsMutable().append(op.getResult());
2706void acc::SerialOp::addReduction(
MLIRContext *context,
2707 mlir::acc::ReductionOp op,
2708 mlir::acc::ReductionRecipeOp recipe) {
2709 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2710 getReductionOperandsMutable().append(op.getResult());
2717unsigned KernelsOp::getNumDataOperands() {
2718 return getDataClauseOperands().size();
2721Value KernelsOp::getDataOperand(
unsigned i) {
2723 numOptional += getWaitOperands().size();
2724 numOptional += getNumGangs().size();
2725 numOptional += getNumWorkers().size();
2726 numOptional += getVectorLength().size();
2727 numOptional += getIfCond() ? 1 : 0;
2728 numOptional += getSelfCond() ? 1 : 0;
2729 return getOperand(numOptional + i);
2732bool acc::KernelsOp::hasAsyncOnly() {
2733 return hasAsyncOnly(mlir::acc::DeviceType::None);
2736bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2741 return getAsyncValue(mlir::acc::DeviceType::None);
2744mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2750 return getNumWorkersValue(mlir::acc::DeviceType::None);
2754acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2759mlir::Value acc::KernelsOp::getVectorLengthValue() {
2760 return getVectorLengthValue(mlir::acc::DeviceType::None);
2764acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2766 getVectorLength(), deviceType);
2770 return getNumGangsValues(mlir::acc::DeviceType::None);
2774KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2776 getNumGangsSegments(), deviceType);
2779bool acc::KernelsOp::hasWaitOnly() {
2780 return hasWaitOnly(mlir::acc::DeviceType::None);
2783bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2788 return getWaitValues(mlir::acc::DeviceType::None);
2792KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2794 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2795 getHasWaitDevnum(), deviceType);
2799 return getWaitDevnum(mlir::acc::DeviceType::None);
2802mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2804 getWaitOperandsSegments(), getHasWaitDevnum(),
2808LogicalResult acc::KernelsOp::verify() {
2810 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2811 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2815 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2816 getWaitOperandsDeviceTypeAttr(),
"wait")))
2820 getNumWorkersDeviceTypeAttr(),
2825 getVectorLengthDeviceTypeAttr(),
2830 getAsyncOperandsDeviceTypeAttr(),
2840void acc::KernelsOp::addPrivatization(
MLIRContext *context,
2841 mlir::acc::PrivateOp op,
2842 mlir::acc::PrivateRecipeOp recipe) {
2843 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2844 getPrivateOperandsMutable().append(op.getResult());
2847void acc::KernelsOp::addFirstPrivatization(
2848 MLIRContext *context, mlir::acc::FirstprivateOp op,
2849 mlir::acc::FirstprivateRecipeOp recipe) {
2850 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2851 getFirstprivateOperandsMutable().append(op.getResult());
2854void acc::KernelsOp::addReduction(
MLIRContext *context,
2855 mlir::acc::ReductionOp op,
2856 mlir::acc::ReductionRecipeOp recipe) {
2857 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2858 getReductionOperandsMutable().append(op.getResult());
2861void acc::KernelsOp::addNumWorkersOperand(
2864 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2865 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2866 getNumWorkersMutable()));
2869void acc::KernelsOp::addVectorLengthOperand(
2872 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2873 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2874 getVectorLengthMutable()));
2876void acc::KernelsOp::addAsyncOnly(
2878 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2879 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2882void acc::KernelsOp::addAsyncOperand(
2885 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2886 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2887 getAsyncOperandsMutable()));
2890void acc::KernelsOp::addNumGangsOperands(
2894 if (getNumGangsSegmentsAttr())
2895 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2897 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2898 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2899 getNumGangsMutable(), segments));
2901 setNumGangsSegments(segments);
2904void acc::KernelsOp::addWaitOnly(
2906 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2907 effectiveDeviceTypes));
2909void acc::KernelsOp::addWaitOperands(
2914 if (getWaitOperandsSegments())
2915 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2917 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2918 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2919 getWaitOperandsMutable(), segments));
2920 setWaitOperandsSegments(segments);
2923 if (getHasWaitDevnumAttr())
2924 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2927 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2929 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2936LogicalResult acc::HostDataOp::verify() {
2937 if (getDataClauseOperands().empty())
2938 return emitError(
"at least one operand must appear on the host_data "
2942 for (
mlir::Value operand : getDataClauseOperands()) {
2944 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
2946 return emitError(
"expect data entry operation as defining op");
2949 if (!seenVars.insert(useDeviceOp.getVar()).second)
2950 return emitError(
"duplicate use_device variable");
2957 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2964void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2966 results.
add<RemoveEmptyKernelEnvironment>(context);
2978 bool &needCommaBetweenValues,
bool &newValue) {
2985 attributes.push_back(gangArgType);
2986 needCommaBetweenValues =
true;
2997 mlir::ArrayAttr &gangOnlyDeviceType) {
3002 bool needCommaBetweenValues =
false;
3003 bool needCommaBeforeOperands =
false;
3007 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3008 parser.
getContext(), mlir::acc::DeviceType::None));
3009 gangOnlyDeviceType =
3010 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
3018 if (parser.parseAttribute(
3019 gangOnlyDeviceTypeAttributes.emplace_back()))
3026 needCommaBeforeOperands =
true;
3029 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3030 mlir::acc::GangArgType::Num);
3031 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
3032 mlir::acc::GangArgType::Dim);
3033 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3034 parser.
getContext(), mlir::acc::GangArgType::Static);
3037 if (needCommaBeforeOperands) {
3038 needCommaBeforeOperands =
false;
3045 int32_t crtOperandsSize = gangOperands.size();
3047 bool newValue =
false;
3048 bool needValue =
false;
3049 if (needCommaBetweenValues) {
3057 gangOperands, gangOperandsType,
3058 gangArgTypeAttributes, argNum,
3059 needCommaBetweenValues, newValue)))
3062 gangOperands, gangOperandsType,
3063 gangArgTypeAttributes, argDim,
3064 needCommaBetweenValues, newValue)))
3066 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3067 gangOperands, gangOperandsType,
3068 gangArgTypeAttributes, argStatic,
3069 needCommaBetweenValues, newValue)))
3072 if (!newValue && needValue) {
3074 "new value expected after comma");
3082 if (gangOperands.empty())
3085 "expect at least one of num, dim or static values");
3091 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
3095 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3096 parser.
getContext(), mlir::acc::DeviceType::None));
3099 seg.push_back(gangOperands.size() - crtOperandsSize);
3107 gangArgTypeAttributes.end());
3108 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
3109 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
3112 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3113 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
3121 std::optional<mlir::ArrayAttr> gangArgTypes,
3122 std::optional<mlir::ArrayAttr> deviceTypes,
3123 std::optional<mlir::DenseI32ArrayAttr> segments,
3124 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3126 if (operands.begin() == operands.end() &&
3141 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
3143 llvm::interleaveComma(
3144 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
3145 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3146 (*gangArgTypes)[opIdx]);
3147 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3148 p << LoopOp::getGangNumKeyword();
3149 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3150 p << LoopOp::getGangDimKeyword();
3151 else if (gangArgTypeAttr.getValue() ==
3152 mlir::acc::GangArgType::Static)
3153 p << LoopOp::getGangStaticKeyword();
3154 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
3165 std::optional<mlir::ArrayAttr> segments,
3166 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3169 for (
auto attr : *segments) {
3170 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3171 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3179static std::optional<mlir::acc::DeviceType>
3181 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3183 return std::nullopt;
3184 for (
auto attr : deviceTypes) {
3185 auto deviceTypeAttr =
3186 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3187 if (!deviceTypeAttr)
3188 return mlir::acc::DeviceType::None;
3189 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3190 return deviceTypeAttr.getValue();
3192 return std::nullopt;
3195LogicalResult acc::LoopOp::verify() {
3196 if (getUpperbound().size() != getStep().size())
3197 return emitError() <<
"number of upperbounds expected to be the same as "
3200 if (getUpperbound().size() != getLowerbound().size())
3201 return emitError() <<
"number of upperbounds expected to be the same as "
3202 "number of lowerbounds";
3204 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3205 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3206 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3207 <<
" as upperbound size";
3210 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3211 return emitOpError() <<
"collapse device_type attr must be define when"
3212 <<
" collapse attr is present";
3214 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3215 getCollapseAttr().getValue().size() !=
3216 getCollapseDeviceTypeAttr().getValue().size())
3217 return emitOpError() <<
"collapse attribute count must match collapse"
3218 <<
" device_type count";
3219 if (
auto duplicateDeviceType =
checkDeviceTypes(getCollapseDeviceTypeAttr()))
3221 << acc::stringifyDeviceType(*duplicateDeviceType)
3222 <<
"` found in collapseDeviceType attribute";
3225 if (!getGangOperands().empty()) {
3226 if (!getGangOperandsArgType())
3227 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3228 <<
" when gang operands are present";
3230 if (getGangOperands().size() !=
3231 getGangOperandsArgTypeAttr().getValue().size())
3232 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3233 <<
" gangOperands count";
3235 if (getGangAttr()) {
3238 << acc::stringifyDeviceType(*duplicateDeviceType)
3239 <<
"` found in gang attribute";
3243 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3244 getGangOperandsDeviceTypeAttr(),
"gang")))
3250 << acc::stringifyDeviceType(*duplicateDeviceType)
3251 <<
"` found in worker attribute";
3252 if (
auto duplicateDeviceType =
3255 << acc::stringifyDeviceType(*duplicateDeviceType)
3256 <<
"` found in workerNumOperandsDeviceType attribute";
3258 getWorkerNumOperandsDeviceTypeAttr(),
3265 << acc::stringifyDeviceType(*duplicateDeviceType)
3266 <<
"` found in vector attribute";
3267 if (
auto duplicateDeviceType =
3270 << acc::stringifyDeviceType(*duplicateDeviceType)
3271 <<
"` found in vectorOperandsDeviceType attribute";
3273 getVectorOperandsDeviceTypeAttr(),
3278 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3279 getTileOperandsDeviceTypeAttr(),
"tile")))
3283 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3287 return emitError() <<
"only one of auto, independent, seq can be present "
3293 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3294 return attr.getValue() == mlir::acc::DeviceType::None;
3296 bool hasDefaultSeq =
3298 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3301 bool hasDefaultIndependent =
3302 getIndependentAttr()
3304 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3307 bool hasDefaultAuto =
3309 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3312 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3314 <<
"at least one of auto, independent, seq must be present";
3319 for (
auto attr : getSeqAttr()) {
3320 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3321 if (hasVector(deviceTypeAttr.getValue()) ||
3322 getVectorValue(deviceTypeAttr.getValue()) ||
3323 hasWorker(deviceTypeAttr.getValue()) ||
3324 getWorkerValue(deviceTypeAttr.getValue()) ||
3325 hasGang(deviceTypeAttr.getValue()) ||
3326 getGangValue(mlir::acc::GangArgType::Num,
3327 deviceTypeAttr.getValue()) ||
3328 getGangValue(mlir::acc::GangArgType::Dim,
3329 deviceTypeAttr.getValue()) ||
3330 getGangValue(mlir::acc::GangArgType::Static,
3331 deviceTypeAttr.getValue()))
3332 return emitError() <<
"gang, worker or vector cannot appear with seq";
3337 mlir::acc::PrivateRecipeOp>(
3338 *
this, getPrivateOperands(),
"private")))
3342 mlir::acc::FirstprivateRecipeOp>(
3343 *
this, getFirstprivateOperands(),
"firstprivate")))
3347 mlir::acc::ReductionRecipeOp>(
3348 *
this, getReductionOperands(),
"reduction")))
3351 if (getCombined().has_value() &&
3352 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3353 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3354 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3355 return emitError(
"unexpected combined constructs attribute");
3359 if (getRegion().empty())
3360 return emitError(
"expected non-empty body.");
3362 if (getUnstructured()) {
3363 if (!isContainerLike())
3365 "unstructured acc.loop must not have induction variables");
3366 }
else if (isContainerLike()) {
3370 uint64_t collapseCount = getCollapseValue().value_or(1);
3371 if (getCollapseAttr()) {
3372 for (
auto collapseEntry : getCollapseAttr()) {
3373 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3374 if (intAttr.getValue().getZExtValue() > collapseCount)
3375 collapseCount = intAttr.getValue().getZExtValue();
3383 bool foundSibling =
false;
3385 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3387 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3389 foundSibling =
true;
3394 expectedParent = op;
3397 if (collapseCount == 0)
3403 return emitError(
"found sibling loops inside container-like acc.loop");
3404 if (collapseCount != 0)
3405 return emitError(
"failed to find enough loop-like operations inside "
3406 "container-like acc.loop");
3412unsigned LoopOp::getNumDataOperands() {
3413 return getReductionOperands().size() + getPrivateOperands().size() +
3414 getFirstprivateOperands().size();
3417Value LoopOp::getDataOperand(
unsigned i) {
3418 unsigned numOptional =
3419 getLowerbound().size() + getUpperbound().size() + getStep().size();
3420 numOptional += getGangOperands().size();
3421 numOptional += getVectorOperands().size();
3422 numOptional += getWorkerNumOperands().size();
3423 numOptional += getTileOperands().size();
3424 numOptional += getCacheOperands().size();
3425 return getOperand(numOptional + i);
3428bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3430bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3434bool LoopOp::hasIndependent() {
3435 return hasIndependent(mlir::acc::DeviceType::None);
3438bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3442bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3444bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3449 return getVectorValue(mlir::acc::DeviceType::None);
3452mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3454 getVectorOperands(), deviceType);
3457bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3459bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3464 return getWorkerValue(mlir::acc::DeviceType::None);
3467mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3469 getWorkerNumOperands(), deviceType);
3472bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3474bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3479 return getTileValues(mlir::acc::DeviceType::None);
3483LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3485 getTileOperandsSegments(), deviceType);
3488std::optional<int64_t> LoopOp::getCollapseValue() {
3489 return getCollapseValue(mlir::acc::DeviceType::None);
3492std::optional<int64_t>
3493LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3494 if (!getCollapseAttr())
3495 return std::nullopt;
3496 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3498 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3499 return intAttr.getValue().getZExtValue();
3501 return std::nullopt;
3504mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3505 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3508mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3509 mlir::acc::DeviceType deviceType) {
3510 if (getGangOperands().empty())
3512 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3513 int32_t nbOperandsBefore = 0;
3514 for (
unsigned i = 0; i < *pos; ++i)
3515 nbOperandsBefore += (*getGangOperandsSegments())[i];
3518 .drop_front(nbOperandsBefore)
3519 .take_front((*getGangOperandsSegments())[*pos]);
3521 int32_t argTypeIdx = nbOperandsBefore;
3522 for (
auto value : values) {
3523 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3524 (*getGangOperandsArgType())[argTypeIdx]);
3525 if (gangArgTypeAttr.getValue() == gangArgType)
3533bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3535bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3540 return {&getRegion()};
3584 if (!regionArgs.empty()) {
3585 p << acc::LoopOp::getControlKeyword() <<
"(";
3586 llvm::interleaveComma(regionArgs, p,
3588 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3589 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3590 <<
" : " << stepType <<
") ";
3597 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3598 effectiveDeviceTypes));
3601void acc::LoopOp::addIndependent(
3603 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3604 context, getIndependentAttr(), effectiveDeviceTypes));
3609 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3610 effectiveDeviceTypes));
3613void acc::LoopOp::setCollapseForDeviceTypes(
3615 llvm::APInt value) {
3619 assert((getCollapseAttr() ==
nullptr) ==
3620 (getCollapseDeviceTypeAttr() ==
nullptr));
3621 assert(value.getBitWidth() == 64);
3623 if (getCollapseAttr()) {
3624 for (
const auto &existing :
3625 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3626 newValues.push_back(std::get<0>(existing));
3627 newDeviceTypes.push_back(std::get<1>(existing));
3631 if (effectiveDeviceTypes.empty()) {
3634 newValues.push_back(
3635 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3636 newDeviceTypes.push_back(
3637 acc::DeviceTypeAttr::get(context, DeviceType::None));
3639 for (DeviceType dt : effectiveDeviceTypes) {
3640 newValues.push_back(
3641 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3642 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3646 setCollapseAttr(ArrayAttr::get(context, newValues));
3647 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3650void acc::LoopOp::setTileForDeviceTypes(
3654 if (getTileOperandsSegments())
3655 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3657 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3658 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3659 getTileOperandsMutable(), segments));
3661 setTileOperandsSegments(segments);
3664void acc::LoopOp::addVectorOperand(
3667 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3668 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3669 newValue, getVectorOperandsMutable()));
3672void acc::LoopOp::addEmptyVector(
3674 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3675 effectiveDeviceTypes));
3678void acc::LoopOp::addWorkerNumOperand(
3681 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3682 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3683 newValue, getWorkerNumOperandsMutable()));
3686void acc::LoopOp::addEmptyWorker(
3688 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3689 effectiveDeviceTypes));
3692void acc::LoopOp::addEmptyGang(
3694 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3695 effectiveDeviceTypes));
3698bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3699 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3700 return attr.getValue() == dt;
3702 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3703 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3706 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3708 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3710 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3716bool acc::LoopOp::hasDefaultGangWorkerVector() {
3717 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3718 hasGang() || getGangValue(GangArgType::Num) ||
3719 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3723acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3724 if (hasSeq(deviceType))
3725 return LoopParMode::loop_seq;
3726 if (hasAuto(deviceType))
3727 return LoopParMode::loop_auto;
3728 if (hasIndependent(deviceType))
3729 return LoopParMode::loop_independent;
3731 return LoopParMode::loop_seq;
3733 return LoopParMode::loop_auto;
3734 assert(hasIndependent() &&
3735 "loop must have default auto, seq, or independent");
3736 return LoopParMode::loop_independent;
3739void acc::LoopOp::addGangOperands(
3744 getGangOperandsSegments())
3745 llvm::copy(*existingSegments, std::back_inserter(segments));
3747 unsigned beforeCount = segments.size();
3749 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3750 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3751 getGangOperandsMutable(), segments));
3753 setGangOperandsSegments(segments);
3760 unsigned numAdded = segments.size() - beforeCount;
3764 if (getGangOperandsArgTypeAttr())
3765 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3767 for (
auto i : llvm::index_range(0u, numAdded)) {
3768 llvm::transform(argTypes, std::back_inserter(gangTypes),
3769 [=](mlir::acc::GangArgType gangTy) {
3770 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3775 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3779void acc::LoopOp::addPrivatization(
MLIRContext *context,
3780 mlir::acc::PrivateOp op,
3781 mlir::acc::PrivateRecipeOp recipe) {
3782 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3783 getPrivateOperandsMutable().append(op.getResult());
3786void acc::LoopOp::addFirstPrivatization(
3787 MLIRContext *context, mlir::acc::FirstprivateOp op,
3788 mlir::acc::FirstprivateRecipeOp recipe) {
3789 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3790 getFirstprivateOperandsMutable().append(op.getResult());
3793void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3794 mlir::acc::ReductionRecipeOp recipe) {
3795 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3796 getReductionOperandsMutable().append(op.getResult());
3803LogicalResult acc::DataOp::verify() {
3808 return emitError(
"at least one operand or the default attribute "
3809 "must appear on the data operation");
3811 for (
mlir::Value operand : getDataClauseOperands())
3812 if (isa<BlockArgument>(operand) ||
3813 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3814 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3815 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3816 operand.getDefiningOp()))
3817 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3826unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3828Value DataOp::getDataOperand(
unsigned i) {
3829 unsigned numOptional = getIfCond() ? 1 : 0;
3831 numOptional += getWaitOperands().size();
3832 return getOperand(numOptional + i);
3835bool acc::DataOp::hasAsyncOnly() {
3836 return hasAsyncOnly(mlir::acc::DeviceType::None);
3839bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3844 return getAsyncValue(mlir::acc::DeviceType::None);
3847mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3852bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3854bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3859 return getWaitValues(mlir::acc::DeviceType::None);
3863DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3865 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3866 getHasWaitDevnum(), deviceType);
3870 return getWaitDevnum(mlir::acc::DeviceType::None);
3873mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3875 getWaitOperandsSegments(), getHasWaitDevnum(),
3879void acc::DataOp::addAsyncOnly(
3881 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3882 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3885void acc::DataOp::addAsyncOperand(
3888 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3889 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3890 getAsyncOperandsMutable()));
3893void acc::DataOp::addWaitOnly(
MLIRContext *context,
3895 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3896 effectiveDeviceTypes));
3899void acc::DataOp::addWaitOperands(
3904 if (getWaitOperandsSegments())
3905 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3907 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3908 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3909 getWaitOperandsMutable(), segments));
3910 setWaitOperandsSegments(segments);
3913 if (getHasWaitDevnumAttr())
3914 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3917 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3919 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3926LogicalResult acc::ExitDataOp::verify() {
3930 if (getDataClauseOperands().empty())
3931 return emitError(
"at least one operand must be present in dataOperands on "
3932 "the exit data operation");
3936 if (getAsyncOperand() && getAsync())
3937 return emitError(
"async attribute cannot appear with asyncOperand");
3941 if (!getWaitOperands().empty() && getWait())
3942 return emitError(
"wait attribute cannot appear with waitOperands");
3944 if (getWaitDevnum() && getWaitOperands().empty())
3945 return emitError(
"wait_devnum cannot appear without waitOperands");
3950unsigned ExitDataOp::getNumDataOperands() {
3951 return getDataClauseOperands().size();
3954Value ExitDataOp::getDataOperand(
unsigned i) {
3955 unsigned numOptional = getIfCond() ? 1 : 0;
3956 numOptional += getAsyncOperand() ? 1 : 0;
3957 numOptional += getWaitDevnum() ? 1 : 0;
3958 return getOperand(getWaitOperands().size() + numOptional + i);
3963 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3966void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3968 assert(effectiveDeviceTypes.empty());
3969 assert(!getAsyncAttr());
3970 assert(!getAsyncOperand());
3972 setAsyncAttr(mlir::UnitAttr::get(context));
3975void ExitDataOp::addAsyncOperand(
3978 assert(effectiveDeviceTypes.empty());
3979 assert(!getAsyncAttr());
3980 assert(!getAsyncOperand());
3982 getAsyncOperandMutable().append(newValue);
3987 assert(effectiveDeviceTypes.empty());
3988 assert(!getWaitAttr());
3989 assert(getWaitOperands().empty());
3990 assert(!getWaitDevnum());
3992 setWaitAttr(mlir::UnitAttr::get(context));
3995void ExitDataOp::addWaitOperands(
3998 assert(effectiveDeviceTypes.empty());
3999 assert(!getWaitAttr());
4000 assert(getWaitOperands().empty());
4001 assert(!getWaitDevnum());
4006 getWaitDevnumMutable().append(newValues.front());
4007 newValues = newValues.drop_front();
4010 getWaitOperandsMutable().append(newValues);
4017LogicalResult acc::EnterDataOp::verify() {
4021 if (getDataClauseOperands().empty())
4022 return emitError(
"at least one operand must be present in dataOperands on "
4023 "the enter data operation");
4027 if (getAsyncOperand() && getAsync())
4028 return emitError(
"async attribute cannot appear with asyncOperand");
4032 if (!getWaitOperands().empty() && getWait())
4033 return emitError(
"wait attribute cannot appear with waitOperands");
4035 if (getWaitDevnum() && getWaitOperands().empty())
4036 return emitError(
"wait_devnum cannot appear without waitOperands");
4038 for (
mlir::Value operand : getDataClauseOperands())
4039 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4040 operand.getDefiningOp()))
4041 return emitError(
"expect data entry operation as defining op");
4046unsigned EnterDataOp::getNumDataOperands() {
4047 return getDataClauseOperands().size();
4050Value EnterDataOp::getDataOperand(
unsigned i) {
4051 unsigned numOptional = getIfCond() ? 1 : 0;
4052 numOptional += getAsyncOperand() ? 1 : 0;
4053 numOptional += getWaitDevnum() ? 1 : 0;
4054 return getOperand(getWaitOperands().size() + numOptional + i);
4059 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
4062void EnterDataOp::addAsyncOnly(
4064 assert(effectiveDeviceTypes.empty());
4065 assert(!getAsyncAttr());
4066 assert(!getAsyncOperand());
4068 setAsyncAttr(mlir::UnitAttr::get(context));
4071void EnterDataOp::addAsyncOperand(
4074 assert(effectiveDeviceTypes.empty());
4075 assert(!getAsyncAttr());
4076 assert(!getAsyncOperand());
4078 getAsyncOperandMutable().append(newValue);
4081void EnterDataOp::addWaitOnly(
MLIRContext *context,
4083 assert(effectiveDeviceTypes.empty());
4084 assert(!getWaitAttr());
4085 assert(getWaitOperands().empty());
4086 assert(!getWaitDevnum());
4088 setWaitAttr(mlir::UnitAttr::get(context));
4091void EnterDataOp::addWaitOperands(
4094 assert(effectiveDeviceTypes.empty());
4095 assert(!getWaitAttr());
4096 assert(getWaitOperands().empty());
4097 assert(!getWaitDevnum());
4102 getWaitDevnumMutable().append(newValues.front());
4103 newValues = newValues.drop_front();
4106 getWaitOperandsMutable().append(newValues);
4113LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
4119LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
4125LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4132 if (
Value writeVal = op.getWriteOpVal()) {
4141LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
4143LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
4149AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4150 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4152 return dyn_cast<AtomicReadOp>(getSecondOp());
4155AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4156 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4158 return dyn_cast<AtomicWriteOp>(getSecondOp());
4161AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4162 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4164 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4167LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
4173template <
typename Op>
4176 bool requireAtLeastOneOperand =
true) {
4177 if (operands.empty() && requireAtLeastOneOperand)
4180 "at least one operand must appear on the declare operation");
4183 if (isa<BlockArgument>(operand) ||
4184 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4185 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4186 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4187 operand.getDefiningOp()))
4189 "expect valid declare data entry operation or acc.getdeviceptr "
4193 assert(var &&
"declare operands can only be data entry operations which "
4196 std::optional<mlir::acc::DataClause> dataClauseOptional{
4198 assert(dataClauseOptional.has_value() &&
4199 "declare operands can only be data entry operations which must have "
4201 (
void)dataClauseOptional;
4207LogicalResult acc::DeclareEnterOp::verify() {
4215LogicalResult acc::DeclareExitOp::verify() {
4226LogicalResult acc::DeclareOp::verify() {
4235 acc::DeviceType dtype) {
4236 unsigned parallelism = 0;
4237 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4238 parallelism += op.hasWorker(dtype) ? 1 : 0;
4239 parallelism += op.hasVector(dtype) ? 1 : 0;
4240 parallelism += op.hasSeq(dtype) ? 1 : 0;
4244LogicalResult acc::RoutineOp::verify() {
4245 unsigned baseParallelism =
4248 if (baseParallelism > 1)
4249 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4250 "be present at the same time";
4252 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4254 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4255 if (dtype == acc::DeviceType::None)
4259 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4260 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4261 "be present at the same time for device_type `"
4262 << acc::stringifyDeviceType(dtype) <<
"`";
4269 mlir::ArrayAttr &bindIdName,
4270 mlir::ArrayAttr &bindStrName,
4271 mlir::ArrayAttr &deviceIdTypes,
4272 mlir::ArrayAttr &deviceStrTypes) {
4279 mlir::Attribute newAttr;
4280 bool isSymbolRefAttr;
4281 auto parseResult = parser.parseAttribute(newAttr);
4282 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4283 bindIdNameAttrs.push_back(symbolRefAttr);
4284 isSymbolRefAttr = true;
4285 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4286 bindStrNameAttrs.push_back(stringAttr);
4287 isSymbolRefAttr =
false;
4292 if (isSymbolRefAttr) {
4293 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4294 parser.getContext(), mlir::acc::DeviceType::None));
4296 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4297 parser.getContext(), mlir::acc::DeviceType::None));
4300 if (isSymbolRefAttr) {
4301 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4302 parser.parseRSquare())
4305 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4306 parser.parseRSquare())
4314 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4315 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4316 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4317 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4323 std::optional<mlir::ArrayAttr> bindIdName,
4324 std::optional<mlir::ArrayAttr> bindStrName,
4325 std::optional<mlir::ArrayAttr> deviceIdTypes,
4326 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4333 allBindNames.append(bindIdName->begin(), bindIdName->end());
4334 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4339 allBindNames.append(bindStrName->begin(), bindStrName->end());
4340 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4344 if (!allBindNames.empty())
4345 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4346 [&](
const auto &pair) {
4347 p << std::get<0>(pair);
4353 mlir::ArrayAttr &gang,
4354 mlir::ArrayAttr &gangDim,
4355 mlir::ArrayAttr &gangDimDeviceTypes) {
4358 gangDimDeviceTypeAttrs;
4359 bool needCommaBeforeOperands =
false;
4363 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4364 parser.
getContext(), mlir::acc::DeviceType::None));
4365 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4372 if (parser.parseAttribute(gangAttrs.emplace_back()))
4379 needCommaBeforeOperands =
true;
4382 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4386 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4387 parser.parseColon() ||
4388 parser.parseAttribute(gangDimAttrs.emplace_back()))
4390 if (succeeded(parser.parseOptionalLSquare())) {
4391 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4392 parser.parseRSquare())
4395 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4396 parser.getContext(), mlir::acc::DeviceType::None));
4402 if (
failed(parser.parseRParen()))
4405 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4406 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4407 gangDimDeviceTypes =
4408 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4414 std::optional<mlir::ArrayAttr> gang,
4415 std::optional<mlir::ArrayAttr> gangDim,
4416 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4419 gang->size() == 1) {
4420 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4421 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4433 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4434 [&](
const auto &pair) {
4435 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4436 p << std::get<0>(pair);
4444 mlir::ArrayAttr &deviceTypes) {
4448 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4449 parser.
getContext(), mlir::acc::DeviceType::None));
4450 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4457 if (parser.parseAttribute(attributes.emplace_back()))
4465 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4471 std::optional<mlir::ArrayAttr> deviceTypes) {
4474 auto deviceTypeAttr =
4475 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4476 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4485 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4491bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4493bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4497bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4499bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4503bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4505bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4509std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4510RoutineOp::getBindNameValue() {
4511 return getBindNameValue(mlir::acc::DeviceType::None);
4514std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4515RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4518 return std::nullopt;
4521 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4522 auto attr = (*getBindIdName())[*pos];
4523 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4524 assert(symbolRefAttr &&
"expected SymbolRef");
4525 return symbolRefAttr;
4528 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4529 auto attr = (*getBindStrName())[*pos];
4530 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4531 assert(stringAttr &&
"expected String");
4535 return std::nullopt;
4538bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4540bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4544std::optional<int64_t> RoutineOp::getGangDimValue() {
4545 return getGangDimValue(mlir::acc::DeviceType::None);
4548std::optional<int64_t>
4549RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4551 return std::nullopt;
4552 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4553 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4554 return intAttr.getInt();
4556 return std::nullopt;
4561 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4562 effectiveDeviceTypes));
4567 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4568 effectiveDeviceTypes));
4573 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4574 effectiveDeviceTypes));
4579 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4580 effectiveDeviceTypes));
4589 if (getGangDimAttr())
4590 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4591 if (getGangDimDeviceTypeAttr())
4592 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4594 assert(dimValues.size() == deviceTypes.size());
4596 if (effectiveDeviceTypes.empty()) {
4597 dimValues.push_back(
4598 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4599 deviceTypes.push_back(
4600 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4602 for (DeviceType dt : effectiveDeviceTypes) {
4603 dimValues.push_back(
4604 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4605 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4608 assert(dimValues.size() == deviceTypes.size());
4610 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4611 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4614void RoutineOp::addBindStrName(
MLIRContext *context,
4616 mlir::StringAttr val) {
4617 unsigned before = getBindStrNameDeviceTypeAttr()
4618 ? getBindStrNameDeviceTypeAttr().size()
4621 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4622 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4623 unsigned after = getBindStrNameDeviceTypeAttr().size();
4626 if (getBindStrNameAttr())
4627 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4628 for (
unsigned i = 0; i < after - before; ++i)
4629 vals.push_back(val);
4631 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4634void RoutineOp::addBindIDName(
MLIRContext *context,
4636 mlir::SymbolRefAttr val) {
4638 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4640 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4641 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4642 unsigned after = getBindIdNameDeviceTypeAttr().size();
4645 if (getBindIdNameAttr())
4646 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4647 for (
unsigned i = 0; i < after - before; ++i)
4648 vals.push_back(val);
4650 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4657LogicalResult acc::InitOp::verify() {
4661 return emitOpError(
"cannot be nested in a compute operation");
4665void acc::InitOp::addDeviceType(
MLIRContext *context,
4666 mlir::acc::DeviceType deviceType) {
4668 if (getDeviceTypesAttr())
4669 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4671 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4672 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4679LogicalResult acc::ShutdownOp::verify() {
4683 return emitOpError(
"cannot be nested in a compute operation");
4687void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4688 mlir::acc::DeviceType deviceType) {
4690 if (getDeviceTypesAttr())
4691 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4693 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4694 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4701LogicalResult acc::SetOp::verify() {
4705 return emitOpError(
"cannot be nested in a compute operation");
4706 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4707 return emitOpError(
"at least one default_async, device_num, or device_type "
4708 "operand must appear");
4716LogicalResult acc::UpdateOp::verify() {
4718 if (getDataClauseOperands().empty())
4719 return emitError(
"at least one value must be present in dataOperands");
4722 getAsyncOperandsDeviceTypeAttr(),
4727 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4728 getWaitOperandsDeviceTypeAttr(),
"wait")))
4734 for (
mlir::Value operand : getDataClauseOperands())
4735 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4736 operand.getDefiningOp()))
4737 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4743unsigned UpdateOp::getNumDataOperands() {
4744 return getDataClauseOperands().size();
4747Value UpdateOp::getDataOperand(
unsigned i) {
4749 numOptional += getIfCond() ? 1 : 0;
4750 return getOperand(getWaitOperands().size() + numOptional + i);
4755 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4758bool UpdateOp::hasAsyncOnly() {
4759 return hasAsyncOnly(mlir::acc::DeviceType::None);
4762bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4767 return getAsyncValue(mlir::acc::DeviceType::None);
4770mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4780bool UpdateOp::hasWaitOnly() {
4781 return hasWaitOnly(mlir::acc::DeviceType::None);
4784bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4789 return getWaitValues(mlir::acc::DeviceType::None);
4793UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4795 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4796 getHasWaitDevnum(), deviceType);
4800 return getWaitDevnum(mlir::acc::DeviceType::None);
4803mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4805 getWaitOperandsSegments(), getHasWaitDevnum(),
4811 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4812 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4815void UpdateOp::addAsyncOperand(
4818 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4819 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4820 getAsyncOperandsMutable()));
4825 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4826 effectiveDeviceTypes));
4829void UpdateOp::addWaitOperands(
4834 if (getWaitOperandsSegments())
4835 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4837 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4838 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4839 getWaitOperandsMutable(), segments));
4840 setWaitOperandsSegments(segments);
4843 if (getHasWaitDevnumAttr())
4844 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4847 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4849 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4856LogicalResult acc::WaitOp::verify() {
4859 if (getAsyncOperand() && getAsync())
4860 return emitError(
"async attribute cannot appear with asyncOperand");
4862 if (getWaitDevnum() && getWaitOperands().empty())
4863 return emitError(
"wait_devnum cannot appear without waitOperands");
4868#define GET_OP_CLASSES
4869#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4871#define GET_ATTRDEF_CLASSES
4872#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4874#define GET_TYPEDEF_CLASSES
4875#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4886 .Case<ACC_DATA_ENTRY_OPS>(
4887 [&](
auto entry) {
return entry.getVarPtr(); })
4888 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4889 [&](
auto exit) {
return exit.getVarPtr(); })
4907 [&](
auto entry) {
return entry.getVarType(); })
4908 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4909 [&](
auto exit) {
return exit.getVarType(); })
4919 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4920 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4930 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4939 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4949 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4951 dataClause.getBounds().begin(), dataClause.getBounds().end());
4963 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4965 dataClause.getAsyncOperands().begin(),
4966 dataClause.getAsyncOperands().end());
4977 return dataClause.getAsyncOperandsDeviceTypeAttr();
4985 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4992 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4999std::optional<mlir::acc::DataClause>
5004 .Case<ACC_DATA_ENTRY_OPS>(
5005 [&](
auto entry) {
return entry.getDataClause(); })
5013 [&](
auto entry) {
return entry.getImplicit(); })
5022 [&](
auto entry) {
return entry.getDataClauseOperands(); })
5024 return dataOperands;
5032 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
5034 return dataOperands;
5041 [&](
auto entry) {
return entry.getRecipeAttr(); })
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This is a utility class for mapping one set of IR entities to another.
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
result_range getResults()
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.