23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/LogicalResult.h"
31#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
39static bool isScalarLikeType(
Type type) {
47 if (!varName.empty()) {
48 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
54struct MemRefPointerLikeModel
55 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
57 return cast<T>(pointer).getElementType();
60 mlir::acc::VariableTypeCategory
63 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
64 return mappableTy.getTypeCategory(varPtr);
66 auto memrefTy = cast<T>(pointer);
67 if (!memrefTy.hasRank()) {
70 return mlir::acc::VariableTypeCategory::uncategorized;
73 if (memrefTy.getRank() == 0) {
74 if (isScalarLikeType(memrefTy.getElementType())) {
75 return mlir::acc::VariableTypeCategory::scalar;
79 return mlir::acc::VariableTypeCategory::uncategorized;
83 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
84 return mlir::acc::VariableTypeCategory::array;
87 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
88 StringRef varName, Type varType, Value originalVar,
89 bool &needsFree)
const {
90 auto memrefTy = cast<MemRefType>(pointer);
94 if (memrefTy.hasStaticShape()) {
96 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
97 attachVarNameAttr(allocaOp, builder, varName);
98 return allocaOp.getResult();
103 if (originalVar && originalVar.
getType() == memrefTy &&
104 memrefTy.hasRank()) {
105 SmallVector<Value> dynamicSizes;
106 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
107 if (memrefTy.isDynamicDim(i)) {
111 memref::DimOp::create(builder, loc, originalVar, indexValue);
112 dynamicSizes.push_back(dimSize);
119 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
120 attachVarNameAttr(allocOp, builder, varName);
121 return allocOp.getResult();
128 bool genFree(Type pointer, OpBuilder &builder, Location loc,
130 Type varType)
const {
133 Value valueToInspect = allocRes ? allocRes : memrefValue;
136 Value currentValue = valueToInspect;
137 Operation *originalAlloc =
nullptr;
141 while (currentValue) {
144 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
145 originalAlloc = definingOp;
150 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
151 currentValue = castOp.getSource();
156 if (
auto reinterpretCastOp =
157 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
158 currentValue = reinterpretCastOp.getSource();
170 if (isa<memref::AllocaOp>(originalAlloc)) {
174 if (isa<memref::AllocOp>(originalAlloc)) {
176 memref::DeallocOp::create(builder, loc, memrefValue);
185 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
189 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
190 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
196 if (destMemref && srcMemref &&
197 destMemref.getType().getElementType() ==
198 srcMemref.getType().getElementType() &&
199 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
200 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
208struct LLVMPointerPointerLikeModel
209 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
210 LLVM::LLVMPointerType> {
218mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
219 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
222 if (existingDeviceTypes)
223 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
225 if (newDeviceTypes.empty())
226 deviceTypes.push_back(
227 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
229 for (DeviceType dt : newDeviceTypes)
230 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
232 return mlir::ArrayAttr::get(context, deviceTypes);
241mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
242 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
247 if (existingDeviceTypes)
248 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
250 if (newDeviceTypes.empty()) {
251 argCollection.
append(arguments);
252 segments.push_back(arguments.size());
253 deviceTypes.push_back(
254 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
257 for (DeviceType dt : newDeviceTypes) {
258 argCollection.
append(arguments);
259 segments.push_back(arguments.size());
260 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
263 return mlir::ArrayAttr::get(context, deviceTypes);
267mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
268 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
272 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
273 newDeviceTypes, arguments,
274 argCollection, segments);
282void OpenACCDialect::initialize() {
285#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
288#define GET_ATTRDEF_LIST
289#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
292#define GET_TYPEDEF_LIST
293#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
299 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
301 UnrankedMemRefType::attachInterface<
302 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
303 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
312 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
316 mlir::acc::DeviceType deviceType) {
320 for (
auto attr : *arrayAttr) {
321 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
322 if (deviceTypeAttr.getValue() == deviceType)
330 std::optional<mlir::ArrayAttr> deviceTypes) {
335 llvm::interleaveComma(*deviceTypes, p,
341 mlir::acc::DeviceType deviceType) {
342 unsigned segmentIdx = 0;
343 for (
auto attr : segments) {
344 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
345 if (deviceTypeAttr.getValue() == deviceType)
346 return std::make_optional(segmentIdx);
356 mlir::acc::DeviceType deviceType) {
358 return range.take_front(0);
359 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
360 int32_t nbOperandsBefore = 0;
361 for (
unsigned i = 0; i < *pos; ++i)
362 nbOperandsBefore += (*segments)[i];
363 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
365 return range.take_front(0);
372 std::optional<mlir::ArrayAttr> hasWaitDevnum,
373 mlir::acc::DeviceType deviceType) {
376 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
377 if (hasWaitDevnum->getValue()[*pos])
388 std::optional<mlir::ArrayAttr> hasWaitDevnum,
389 mlir::acc::DeviceType deviceType) {
394 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
395 if (hasWaitDevnum && *hasWaitDevnum) {
396 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
397 if (boolAttr.getValue())
398 return range.drop_front(1);
404template <
typename Op>
406 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
408 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
413 op.hasAsyncOnly(dtype))
415 "asyncOnly attribute cannot appear with asyncOperand");
420 op.hasWaitOnly(dtype))
421 return op.
emitError(
"wait attribute cannot appear with waitOperands");
426template <
typename Op>
429 return op.
emitError(
"must have var operand");
432 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
433 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
434 return op.
emitError(
"var must be mappable or pointer-like");
437 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
438 op.getVarType() == op.getVar().getType())
439 return op.
emitError(
"varType must capture the element type of var");
444template <
typename Op>
446 if (op.getVar().getType() != op.getAccVar().getType())
447 return op.
emitError(
"input and output types must match");
452template <
typename Op>
454 if (op.getModifiers() != acc::DataClauseModifier::none)
455 return op.
emitError(
"no data clause modifiers are allowed");
459template <
typename Op>
462 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
464 "invalid data clause modifiers: " +
465 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
487 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
508 if (failed(parser.
parseType(accVarType)))
518 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
530 mlir::TypeAttr &varTypeAttr) {
531 if (failed(parser.
parseType(varPtrType)))
542 varTypeAttr = mlir::TypeAttr::get(varType);
547 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
548 varTypeAttr = mlir::TypeAttr::get(
549 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
551 varTypeAttr = mlir::TypeAttr::get(varPtrType);
558 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
566 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
567 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
569 if (typeToCheckAgainst != varType) {
579LogicalResult acc::DataBoundsOp::verify() {
580 auto extent = getExtent();
581 auto upperbound = getUpperbound();
582 if (!extent && !upperbound)
583 return emitError(
"expected extent or upperbound.");
590LogicalResult acc::PrivateOp::verify() {
593 "data clause associated with private operation must match its intent");
604LogicalResult acc::FirstprivateOp::verify() {
606 return emitError(
"data clause associated with firstprivate operation must "
618LogicalResult acc::FirstprivateMapInitialOp::verify() {
620 return emitError(
"data clause associated with firstprivate operation must "
632LogicalResult acc::ReductionOp::verify() {
634 return emitError(
"data clause associated with reduction operation must "
646LogicalResult acc::DevicePtrOp::verify() {
648 return emitError(
"data clause associated with deviceptr operation must "
662LogicalResult acc::PresentOp::verify() {
665 "data clause associated with present operation must match its intent");
678LogicalResult acc::CopyinOp::verify() {
680 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
685 "data clause associated with copyin operation must match its intent"
686 " or specify original clause this operation was decomposed from");
692 acc::DataClauseModifier::always |
693 acc::DataClauseModifier::capture)))
698bool acc::CopyinOp::isCopyinReadonly() {
699 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
700 acc::bitEnumContainsAny(getModifiers(),
701 acc::DataClauseModifier::readonly);
707LogicalResult acc::CreateOp::verify() {
714 "data clause associated with create operation must match its intent"
715 " or specify original clause this operation was decomposed from");
723 acc::DataClauseModifier::always |
724 acc::DataClauseModifier::capture)))
729bool acc::CreateOp::isCreateZero() {
731 return getDataClause() == acc::DataClause::acc_create_zero ||
733 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
739LogicalResult acc::NoCreateOp::verify() {
741 return emitError(
"data clause associated with no_create operation must "
755LogicalResult acc::AttachOp::verify() {
758 "data clause associated with attach operation must match its intent");
772LogicalResult acc::DeclareDeviceResidentOp::verify() {
773 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
774 return emitError(
"data clause associated with device_resident operation "
775 "must match its intent");
789LogicalResult acc::DeclareLinkOp::verify() {
792 "data clause associated with link operation must match its intent");
805LogicalResult acc::CopyoutOp::verify() {
812 "data clause associated with copyout operation must match its intent"
813 " or specify original clause this operation was decomposed from");
815 return emitError(
"must have both host and device pointers");
821 acc::DataClauseModifier::always |
822 acc::DataClauseModifier::capture)))
827bool acc::CopyoutOp::isCopyoutZero() {
828 return getDataClause() == acc::DataClause::acc_copyout_zero ||
829 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
835LogicalResult acc::DeleteOp::verify() {
844 getDataClause() != acc::DataClause::acc_declare_device_resident &&
847 "data clause associated with delete operation must match its intent"
848 " or specify original clause this operation was decomposed from");
850 return emitError(
"must have device pointer");
854 acc::DataClauseModifier::readonly |
855 acc::DataClauseModifier::always |
856 acc::DataClauseModifier::capture)))
864LogicalResult acc::DetachOp::verify() {
869 "data clause associated with detach operation must match its intent"
870 " or specify original clause this operation was decomposed from");
872 return emitError(
"must have device pointer");
881LogicalResult acc::UpdateHostOp::verify() {
886 "data clause associated with host operation must match its intent"
887 " or specify original clause this operation was decomposed from");
889 return emitError(
"must have both host and device pointers");
902LogicalResult acc::UpdateDeviceOp::verify() {
906 "data clause associated with device operation must match its intent"
907 " or specify original clause this operation was decomposed from");
920LogicalResult acc::UseDeviceOp::verify() {
924 "data clause associated with use_device operation must match its intent"
925 " or specify original clause this operation was decomposed from");
938LogicalResult acc::CacheOp::verify() {
943 "data clause associated with cache operation must match its intent"
944 " or specify original clause this operation was decomposed from");
954bool acc::CacheOp::isCacheReadonly() {
955 return getDataClause() == acc::DataClause::acc_cache_readonly ||
956 acc::bitEnumContainsAny(getModifiers(),
957 acc::DataClauseModifier::readonly);
960template <
typename StructureOp>
962 unsigned nRegions = 1) {
965 for (
unsigned i = 0; i < nRegions; ++i)
968 for (
Region *region : regions)
976 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
983template <
typename OpTy>
985 using OpRewritePattern<OpTy>::OpRewritePattern;
987 LogicalResult matchAndRewrite(OpTy op,
988 PatternRewriter &rewriter)
const override {
990 Value ifCond = op.getIfCond();
994 IntegerAttr constAttr;
997 if (constAttr.getInt())
998 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1010 assert(region.
hasOneBlock() &&
"expected single-block region");
1022template <
typename OpTy>
1023struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1024 using OpRewritePattern<OpTy>::OpRewritePattern;
1026 LogicalResult matchAndRewrite(OpTy op,
1027 PatternRewriter &rewriter)
const override {
1029 Value ifCond = op.getIfCond();
1033 IntegerAttr constAttr;
1036 if (constAttr.getInt())
1037 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1047struct RemoveEmptyKernelEnvironment
1049 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1051 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1052 PatternRewriter &rewriter)
const override {
1053 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1064 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1065 for (
auto attr : deviceTypeAttr) {
1066 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1067 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1074 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1075 for (
auto attr : hasDevnumAttr) {
1076 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1077 if (boolAttr.getValue())
1084 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1085 if (segmentsAttr.size() > 1)
1091 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1118 for (
Value bound : bounds) {
1119 argTypes.push_back(bound.getType());
1120 argLocs.push_back(loc);
1127 Value privatizedValue;
1133 if (isa<MappableType>(varType)) {
1134 auto mappableTy = cast<MappableType>(varType);
1135 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1136 privatizedValue = mappableTy.generatePrivateInit(
1137 builder, loc, typedVar, varName, bounds, {}, needsFree);
1138 if (!privatizedValue)
1141 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1142 auto pointerLikeTy = cast<PointerLikeType>(varType);
1144 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1145 blockArgVar, needsFree);
1146 if (!privatizedValue)
1151 acc::YieldOp::create(builder, loc, privatizedValue);
1166 for (
Value bound : bounds) {
1167 copyArgTypes.push_back(bound.getType());
1168 copyArgLocs.push_back(loc);
1175 bool isMappable = isa<MappableType>(varType);
1176 bool isPointerLike = isa<PointerLikeType>(varType);
1179 if (isMappable && !isPointerLike)
1183 if (isPointerLike) {
1184 auto pointerLikeTy = cast<PointerLikeType>(varType);
1189 if (!pointerLikeTy.genCopy(
1196 acc::TerminatorOp::create(builder, loc);
1210 for (
Value bound : bounds) {
1211 destroyArgTypes.push_back(bound.getType());
1212 destroyArgLocs.push_back(loc);
1216 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1220 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1221 if (isa<MappableType>(varType)) {
1222 auto mappableTy = cast<MappableType>(varType);
1223 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1226 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1227 auto pointerLikeTy = cast<PointerLikeType>(varType);
1228 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1232 acc::TerminatorOp::create(builder, loc);
1243 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1245 if (optional && region.
empty())
1249 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1253 return op->
emitOpError() <<
"expects " << regionName
1256 << regionType <<
" type";
1259 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1260 if (yieldOp.getOperands().size() != 1 ||
1261 yieldOp.getOperands().getTypes()[0] != type)
1262 return op->
emitOpError() <<
"expects " << regionName
1264 "yield a value of the "
1265 << regionType <<
" type";
1271LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1273 "privatization",
"init",
getType(),
1277 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1283std::optional<PrivateRecipeOp>
1285 StringRef recipeName,
Type varType,
1288 bool isMappable = isa<MappableType>(varType);
1289 bool isPointerLike = isa<PointerLikeType>(varType);
1292 if (!isMappable && !isPointerLike)
1293 return std::nullopt;
1298 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1301 bool needsFree =
false;
1302 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1303 varName, bounds, needsFree))) {
1305 return std::nullopt;
1312 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1313 Value allocRes = yieldOp.getOperand(0);
1315 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1316 varType, allocRes, bounds))) {
1318 return std::nullopt;
1329LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1331 "privatization",
"init",
getType(),
1335 if (getCopyRegion().empty())
1336 return emitOpError() <<
"expects non-empty copy region";
1341 return emitOpError() <<
"expects copy region with two arguments of the "
1342 "privatization type";
1344 if (getDestroyRegion().empty())
1348 "privatization",
"destroy",
1355std::optional<FirstprivateRecipeOp>
1357 StringRef recipeName,
Type varType,
1360 bool isMappable = isa<MappableType>(varType);
1361 bool isPointerLike = isa<PointerLikeType>(varType);
1364 if (!isMappable && !isPointerLike)
1365 return std::nullopt;
1370 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1373 bool needsFree =
false;
1374 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1375 varName, bounds, needsFree))) {
1377 return std::nullopt;
1381 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1384 return std::nullopt;
1391 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1392 Value allocRes = yieldOp.getOperand(0);
1394 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1395 varType, allocRes, bounds))) {
1397 return std::nullopt;
1408LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1414 if (getCombinerRegion().empty())
1415 return emitOpError() <<
"expects non-empty combiner region";
1417 Block &reductionBlock = getCombinerRegion().
front();
1421 return emitOpError() <<
"expects combiner region with the first two "
1422 <<
"arguments of the reduction type";
1424 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1425 if (yieldOp.getOperands().size() != 1 ||
1426 yieldOp.getOperands().getTypes()[0] !=
getType())
1427 return emitOpError() <<
"expects combiner region to yield a value "
1428 "of the reduction type";
1444 if (parser.parseAttribute(attributes.emplace_back()) ||
1445 parser.parseArrow() ||
1446 parser.parseOperand(operands.emplace_back()) ||
1447 parser.parseColonType(types.emplace_back()))
1454 symbols = ArrayAttr::get(parser.
getContext(), arrayAttr);
1461 std::optional<mlir::ArrayAttr> attributes) {
1462 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
1463 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
1464 << std::get<1>(it).getType();
1473template <
typename Op>
1477 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1478 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1479 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1480 operand.getDefiningOp()))
1482 "expect data entry/exit operation or acc.getdeviceptr "
1487template <
typename Op>
1491 llvm::StringRef symbolName,
bool checkOperandType =
true) {
1492 if (!operands.empty()) {
1493 if (!attributes || attributes->size() != operands.size())
1495 <<
"expected as many " << symbolName <<
" symbol reference as "
1496 << operandName <<
" operands";
1500 <<
"unexpected " << symbolName <<
" symbol reference";
1505 for (
auto args : llvm::zip(operands, *attributes)) {
1508 if (!set.insert(operand).second)
1510 << operandName <<
" operand appears more than once";
1513 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1517 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1518 << operandName <<
" declaration";
1520 if (checkOperandType && decl.getType() && decl.getType() != varType)
1521 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1522 <<
") to be the same type as " << operandName
1523 <<
" declaration (" << decl.getType() <<
")";
1529unsigned ParallelOp::getNumDataOperands() {
1530 return getReductionOperands().size() + getPrivateOperands().size() +
1531 getFirstprivateOperands().size() + getDataClauseOperands().size();
1534Value ParallelOp::getDataOperand(
unsigned i) {
1536 numOptional += getNumGangs().size();
1537 numOptional += getNumWorkers().size();
1538 numOptional += getVectorLength().size();
1539 numOptional += getIfCond() ? 1 : 0;
1540 numOptional += getSelfCond() ? 1 : 0;
1541 return getOperand(getWaitOperands().size() + numOptional + i);
1544template <
typename Op>
1547 llvm::StringRef keyword) {
1548 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1549 return op.
emitOpError() << keyword <<
" operands count must match "
1550 << keyword <<
" device_type count";
1554template <
typename Op>
1557 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1558 std::size_t numOperandsInSegments = 0;
1559 std::size_t nbOfSegments = 0;
1562 for (
auto segCount : segments.
asArrayRef()) {
1563 if (maxInSegment != 0 && segCount > maxInSegment)
1564 return op.
emitOpError() << keyword <<
" expects a maximum of "
1565 << maxInSegment <<
" values per segment";
1566 numOperandsInSegments += segCount;
1571 if ((numOperandsInSegments != operands.size()) ||
1572 (!deviceTypes && !operands.empty()))
1574 << keyword <<
" operand count does not match count in segments";
1575 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1577 << keyword <<
" segment count does not match device_type count";
1581LogicalResult acc::ParallelOp::verify() {
1583 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1584 "privatizations",
false)))
1587 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1588 "firstprivate",
"firstprivatizations",
false)))
1591 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1592 "reductions",
false)))
1596 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1597 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1601 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1602 getWaitOperandsDeviceTypeAttr(),
"wait")))
1606 getNumWorkersDeviceTypeAttr(),
1611 getVectorLengthDeviceTypeAttr(),
1616 getAsyncOperandsDeviceTypeAttr(),
1629 mlir::acc::DeviceType deviceType) {
1632 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1637bool acc::ParallelOp::hasAsyncOnly() {
1638 return hasAsyncOnly(mlir::acc::DeviceType::None);
1641bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1646 return getAsyncValue(mlir::acc::DeviceType::None);
1649mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1654mlir::Value acc::ParallelOp::getNumWorkersValue() {
1655 return getNumWorkersValue(mlir::acc::DeviceType::None);
1659acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1664mlir::Value acc::ParallelOp::getVectorLengthValue() {
1665 return getVectorLengthValue(mlir::acc::DeviceType::None);
1669acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1671 getVectorLength(), deviceType);
1675 return getNumGangsValues(mlir::acc::DeviceType::None);
1679ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1681 getNumGangsSegments(), deviceType);
1684bool acc::ParallelOp::hasWaitOnly() {
1685 return hasWaitOnly(mlir::acc::DeviceType::None);
1688bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1693 return getWaitValues(mlir::acc::DeviceType::None);
1697ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1699 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1700 getHasWaitDevnum(), deviceType);
1704 return getWaitDevnum(mlir::acc::DeviceType::None);
1707mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1709 getWaitOperandsSegments(), getHasWaitDevnum(),
1725 odsBuilder, odsState, asyncOperands,
nullptr,
1726 nullptr, waitOperands,
nullptr,
1728 nullptr, numGangs,
nullptr,
1729 nullptr, numWorkers,
1730 nullptr, vectorLength,
1731 nullptr, ifCond, selfCond,
1732 nullptr, reductionOperands,
nullptr,
1733 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1734 nullptr, dataClauseOperands,
1738void acc::ParallelOp::addNumWorkersOperand(
1741 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1742 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1743 getNumWorkersMutable()));
1745void acc::ParallelOp::addVectorLengthOperand(
1748 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1749 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1750 getVectorLengthMutable()));
1753void acc::ParallelOp::addAsyncOnly(
1755 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1756 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1759void acc::ParallelOp::addAsyncOperand(
1762 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1763 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1764 getAsyncOperandsMutable()));
1767void acc::ParallelOp::addNumGangsOperands(
1771 if (getNumGangsSegments())
1772 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1774 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1775 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1776 getNumGangsMutable(), segments));
1778 setNumGangsSegments(segments);
1780void acc::ParallelOp::addWaitOnly(
1782 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1783 effectiveDeviceTypes));
1785void acc::ParallelOp::addWaitOperands(
1790 if (getWaitOperandsSegments())
1791 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1793 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1794 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1795 getWaitOperandsMutable(), segments));
1796 setWaitOperandsSegments(segments);
1799 if (getHasWaitDevnumAttr())
1800 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1803 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1805 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1808void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1809 mlir::acc::PrivateOp op,
1810 mlir::acc::PrivateRecipeOp recipe) {
1811 getPrivateOperandsMutable().append(op.getResult());
1815 if (getPrivatizationRecipesAttr())
1816 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1819 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1820 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1823void acc::ParallelOp::addFirstPrivatization(
1824 MLIRContext *context, mlir::acc::FirstprivateOp op,
1825 mlir::acc::FirstprivateRecipeOp recipe) {
1826 getFirstprivateOperandsMutable().append(op.getResult());
1830 if (getFirstprivatizationRecipesAttr())
1831 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1834 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1835 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1838void acc::ParallelOp::addReduction(
MLIRContext *context,
1839 mlir::acc::ReductionOp op,
1840 mlir::acc::ReductionRecipeOp recipe) {
1841 getReductionOperandsMutable().append(op.getResult());
1845 if (getReductionRecipesAttr())
1846 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1849 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1850 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1865 int32_t crtOperandsSize = operands.size();
1868 if (parser.parseOperand(operands.emplace_back()) ||
1869 parser.parseColonType(types.emplace_back()))
1874 seg.push_back(operands.size() - crtOperandsSize);
1884 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1885 parser.
getContext(), mlir::acc::DeviceType::None));
1891 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
1898 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1899 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1900 p <<
" [" << attr <<
"]";
1905 std::optional<mlir::ArrayAttr> deviceTypes,
1906 std::optional<mlir::DenseI32ArrayAttr> segments) {
1908 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1910 llvm::interleaveComma(
1911 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1912 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1932 int32_t crtOperandsSize = operands.size();
1936 if (parser.parseOperand(operands.emplace_back()) ||
1937 parser.parseColonType(types.emplace_back()))
1943 seg.push_back(operands.size() - crtOperandsSize);
1953 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1954 parser.
getContext(), mlir::acc::DeviceType::None));
1960 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
1969 std::optional<mlir::DenseI32ArrayAttr> segments) {
1971 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1973 llvm::interleaveComma(
1974 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1975 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1988 mlir::ArrayAttr &keywordOnly) {
1992 bool needCommaBeforeOperands =
false;
1996 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
1997 parser.
getContext(), mlir::acc::DeviceType::None));
1998 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2005 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2012 needCommaBeforeOperands =
true;
2015 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2022 int32_t crtOperandsSize = operands.size();
2034 if (parser.parseOperand(operands.emplace_back()) ||
2035 parser.parseColonType(types.emplace_back()))
2041 seg.push_back(operands.size() - crtOperandsSize);
2051 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2052 parser.
getContext(), mlir::acc::DeviceType::None));
2059 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2060 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2062 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2070 if (attrs->size() != 1)
2072 if (
auto deviceTypeAttr =
2073 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2074 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2080 std::optional<mlir::ArrayAttr> deviceTypes,
2081 std::optional<mlir::DenseI32ArrayAttr> segments,
2082 std::optional<mlir::ArrayAttr> hasDevNum,
2083 std::optional<mlir::ArrayAttr> keywordOnly) {
2096 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2098 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2099 if (boolAttr && boolAttr.getValue())
2101 llvm::interleaveComma(
2102 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2103 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2120 if (parser.parseOperand(operands.emplace_back()) ||
2121 parser.parseColonType(types.emplace_back()))
2123 if (succeeded(parser.parseOptionalLSquare())) {
2124 if (parser.parseAttribute(attributes.emplace_back()) ||
2125 parser.parseRSquare())
2128 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2129 parser.getContext(), mlir::acc::DeviceType::None));
2136 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2143 std::optional<mlir::ArrayAttr> deviceTypes) {
2146 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2147 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2156 mlir::ArrayAttr &keywordOnlyDeviceType) {
2159 bool needCommaBeforeOperands =
false;
2163 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2164 parser.
getContext(), mlir::acc::DeviceType::None));
2165 keywordOnlyDeviceType =
2166 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2174 if (parser.parseAttribute(
2175 keywordOnlyDeviceTypeAttributes.emplace_back()))
2182 needCommaBeforeOperands =
true;
2185 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2190 if (parser.parseOperand(operands.emplace_back()) ||
2191 parser.parseColonType(types.emplace_back()))
2193 if (succeeded(parser.parseOptionalLSquare())) {
2194 if (parser.parseAttribute(attributes.emplace_back()) ||
2195 parser.parseRSquare())
2198 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2199 parser.getContext(), mlir::acc::DeviceType::None));
2205 if (
failed(parser.parseRParen()))
2210 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2217 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2219 if (operands.begin() == operands.end() &&
2235 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2236 mlir::Type &operandType, mlir::UnitAttr &attr) {
2239 attr = mlir::UnitAttr::get(parser.
getContext());
2249 if (failed(parser.
parseType(operandType)))
2259 std::optional<mlir::Value> operand,
2261 mlir::UnitAttr attr) {
2278 attr = mlir::UnitAttr::get(parser.
getContext());
2283 if (parser.parseOperand(operands.emplace_back()))
2291 if (parser.parseType(types.emplace_back()))
2306 mlir::UnitAttr attr) {
2311 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2313 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2319 mlir::acc::CombinedConstructsTypeAttr &attr) {
2321 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2322 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2324 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2325 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2327 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2328 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2331 "expected compute construct name");
2339 mlir::acc::CombinedConstructsTypeAttr attr) {
2341 switch (attr.getValue()) {
2342 case mlir::acc::CombinedConstructsType::KernelsLoop:
2345 case mlir::acc::CombinedConstructsType::ParallelLoop:
2348 case mlir::acc::CombinedConstructsType::SerialLoop:
2359unsigned SerialOp::getNumDataOperands() {
2360 return getReductionOperands().size() + getPrivateOperands().size() +
2361 getFirstprivateOperands().size() + getDataClauseOperands().size();
2364Value SerialOp::getDataOperand(
unsigned i) {
2366 numOptional += getIfCond() ? 1 : 0;
2367 numOptional += getSelfCond() ? 1 : 0;
2368 return getOperand(getWaitOperands().size() + numOptional + i);
2371bool acc::SerialOp::hasAsyncOnly() {
2372 return hasAsyncOnly(mlir::acc::DeviceType::None);
2375bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2380 return getAsyncValue(mlir::acc::DeviceType::None);
2383mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2388bool acc::SerialOp::hasWaitOnly() {
2389 return hasWaitOnly(mlir::acc::DeviceType::None);
2392bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2397 return getWaitValues(mlir::acc::DeviceType::None);
2401SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2403 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2404 getHasWaitDevnum(), deviceType);
2408 return getWaitDevnum(mlir::acc::DeviceType::None);
2411mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2413 getWaitOperandsSegments(), getHasWaitDevnum(),
2417LogicalResult acc::SerialOp::verify() {
2419 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2420 "privatizations",
false)))
2423 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2424 "firstprivate",
"firstprivatizations",
false)))
2427 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2428 "reductions",
false)))
2432 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2433 getWaitOperandsDeviceTypeAttr(),
"wait")))
2437 getAsyncOperandsDeviceTypeAttr(),
2447void acc::SerialOp::addAsyncOnly(
2449 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2450 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2453void acc::SerialOp::addAsyncOperand(
2456 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2457 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2458 getAsyncOperandsMutable()));
2461void acc::SerialOp::addWaitOnly(
2463 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2464 effectiveDeviceTypes));
2466void acc::SerialOp::addWaitOperands(
2471 if (getWaitOperandsSegments())
2472 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2474 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2475 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2476 getWaitOperandsMutable(), segments));
2477 setWaitOperandsSegments(segments);
2480 if (getHasWaitDevnumAttr())
2481 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2484 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2486 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2489void acc::SerialOp::addPrivatization(
MLIRContext *context,
2490 mlir::acc::PrivateOp op,
2491 mlir::acc::PrivateRecipeOp recipe) {
2492 getPrivateOperandsMutable().append(op.getResult());
2496 if (getPrivatizationRecipesAttr())
2497 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2500 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2501 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2504void acc::SerialOp::addFirstPrivatization(
2505 MLIRContext *context, mlir::acc::FirstprivateOp op,
2506 mlir::acc::FirstprivateRecipeOp recipe) {
2507 getFirstprivateOperandsMutable().append(op.getResult());
2511 if (getFirstprivatizationRecipesAttr())
2512 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2515 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2516 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2519void acc::SerialOp::addReduction(
MLIRContext *context,
2520 mlir::acc::ReductionOp op,
2521 mlir::acc::ReductionRecipeOp recipe) {
2522 getReductionOperandsMutable().append(op.getResult());
2526 if (getReductionRecipesAttr())
2527 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2530 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2531 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2538unsigned KernelsOp::getNumDataOperands() {
2539 return getDataClauseOperands().size();
2542Value KernelsOp::getDataOperand(
unsigned i) {
2544 numOptional += getWaitOperands().size();
2545 numOptional += getNumGangs().size();
2546 numOptional += getNumWorkers().size();
2547 numOptional += getVectorLength().size();
2548 numOptional += getIfCond() ? 1 : 0;
2549 numOptional += getSelfCond() ? 1 : 0;
2550 return getOperand(numOptional + i);
2553bool acc::KernelsOp::hasAsyncOnly() {
2554 return hasAsyncOnly(mlir::acc::DeviceType::None);
2557bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2562 return getAsyncValue(mlir::acc::DeviceType::None);
2565mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2571 return getNumWorkersValue(mlir::acc::DeviceType::None);
2575acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2580mlir::Value acc::KernelsOp::getVectorLengthValue() {
2581 return getVectorLengthValue(mlir::acc::DeviceType::None);
2585acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2587 getVectorLength(), deviceType);
2591 return getNumGangsValues(mlir::acc::DeviceType::None);
2595KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2597 getNumGangsSegments(), deviceType);
2600bool acc::KernelsOp::hasWaitOnly() {
2601 return hasWaitOnly(mlir::acc::DeviceType::None);
2604bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2609 return getWaitValues(mlir::acc::DeviceType::None);
2613KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2615 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2616 getHasWaitDevnum(), deviceType);
2620 return getWaitDevnum(mlir::acc::DeviceType::None);
2623mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2625 getWaitOperandsSegments(), getHasWaitDevnum(),
2629LogicalResult acc::KernelsOp::verify() {
2631 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2632 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2636 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2637 getWaitOperandsDeviceTypeAttr(),
"wait")))
2641 getNumWorkersDeviceTypeAttr(),
2646 getVectorLengthDeviceTypeAttr(),
2651 getAsyncOperandsDeviceTypeAttr(),
2661void acc::KernelsOp::addNumWorkersOperand(
2664 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2665 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2666 getNumWorkersMutable()));
2669void acc::KernelsOp::addVectorLengthOperand(
2672 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2673 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2674 getVectorLengthMutable()));
2676void acc::KernelsOp::addAsyncOnly(
2678 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2679 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2682void acc::KernelsOp::addAsyncOperand(
2685 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2686 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2687 getAsyncOperandsMutable()));
2690void acc::KernelsOp::addNumGangsOperands(
2694 if (getNumGangsSegmentsAttr())
2695 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2697 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2698 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2699 getNumGangsMutable(), segments));
2701 setNumGangsSegments(segments);
2704void acc::KernelsOp::addWaitOnly(
2706 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2707 effectiveDeviceTypes));
2709void acc::KernelsOp::addWaitOperands(
2714 if (getWaitOperandsSegments())
2715 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2717 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2718 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2719 getWaitOperandsMutable(), segments));
2720 setWaitOperandsSegments(segments);
2723 if (getHasWaitDevnumAttr())
2724 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2727 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2729 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2736LogicalResult acc::HostDataOp::verify() {
2737 if (getDataClauseOperands().empty())
2738 return emitError(
"at least one operand must appear on the host_data "
2741 for (
mlir::Value operand : getDataClauseOperands())
2742 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2743 return emitError(
"expect data entry operation as defining op");
2749 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2756void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2758 results.
add<RemoveEmptyKernelEnvironment>(context);
2770 bool &needCommaBetweenValues,
bool &newValue) {
2777 attributes.push_back(gangArgType);
2778 needCommaBetweenValues =
true;
2789 mlir::ArrayAttr &gangOnlyDeviceType) {
2794 bool needCommaBetweenValues =
false;
2795 bool needCommaBeforeOperands =
false;
2799 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2800 parser.
getContext(), mlir::acc::DeviceType::None));
2801 gangOnlyDeviceType =
2802 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
2810 if (parser.parseAttribute(
2811 gangOnlyDeviceTypeAttributes.emplace_back()))
2818 needCommaBeforeOperands =
true;
2821 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2822 mlir::acc::GangArgType::Num);
2823 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2824 mlir::acc::GangArgType::Dim);
2825 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2826 parser.
getContext(), mlir::acc::GangArgType::Static);
2829 if (needCommaBeforeOperands) {
2830 needCommaBeforeOperands =
false;
2837 int32_t crtOperandsSize = gangOperands.size();
2839 bool newValue =
false;
2840 bool needValue =
false;
2841 if (needCommaBetweenValues) {
2849 gangOperands, gangOperandsType,
2850 gangArgTypeAttributes, argNum,
2851 needCommaBetweenValues, newValue)))
2854 gangOperands, gangOperandsType,
2855 gangArgTypeAttributes, argDim,
2856 needCommaBetweenValues, newValue)))
2858 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2859 gangOperands, gangOperandsType,
2860 gangArgTypeAttributes, argStatic,
2861 needCommaBetweenValues, newValue)))
2864 if (!newValue && needValue) {
2866 "new value expected after comma");
2874 if (gangOperands.empty())
2877 "expect at least one of num, dim or static values");
2883 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2887 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2888 parser.
getContext(), mlir::acc::DeviceType::None));
2891 seg.push_back(gangOperands.size() - crtOperandsSize);
2899 gangArgTypeAttributes.end());
2900 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
2901 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
2904 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2905 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
2913 std::optional<mlir::ArrayAttr> gangArgTypes,
2914 std::optional<mlir::ArrayAttr> deviceTypes,
2915 std::optional<mlir::DenseI32ArrayAttr> segments,
2916 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2918 if (operands.begin() == operands.end() &&
2933 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2935 llvm::interleaveComma(
2936 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2937 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2938 (*gangArgTypes)[opIdx]);
2939 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2940 p << LoopOp::getGangNumKeyword();
2941 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2942 p << LoopOp::getGangDimKeyword();
2943 else if (gangArgTypeAttr.getValue() ==
2944 mlir::acc::GangArgType::Static)
2945 p << LoopOp::getGangStaticKeyword();
2946 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2957 std::optional<mlir::ArrayAttr> segments,
2958 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2961 for (
auto attr : *segments) {
2962 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2963 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2971 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2974 for (
auto attr : deviceTypes) {
2975 auto deviceTypeAttr =
2976 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
2977 if (!deviceTypeAttr)
2979 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
2985LogicalResult acc::LoopOp::verify() {
2986 if (getUpperbound().size() != getStep().size())
2987 return emitError() <<
"number of upperbounds expected to be the same as "
2990 if (getUpperbound().size() != getLowerbound().size())
2991 return emitError() <<
"number of upperbounds expected to be the same as "
2992 "number of lowerbounds";
2994 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
2995 (getUpperbound().size() != getInclusiveUpperbound()->size()))
2996 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
2997 <<
" as upperbound size";
3000 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3001 return emitOpError() <<
"collapse device_type attr must be define when"
3002 <<
" collapse attr is present";
3004 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3005 getCollapseAttr().getValue().size() !=
3006 getCollapseDeviceTypeAttr().getValue().size())
3007 return emitOpError() <<
"collapse attribute count must match collapse"
3008 <<
" device_type count";
3011 <<
"duplicate device_type found in collapseDeviceType attribute";
3014 if (!getGangOperands().empty()) {
3015 if (!getGangOperandsArgType())
3016 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3017 <<
" when gang operands are present";
3019 if (getGangOperands().size() !=
3020 getGangOperandsArgTypeAttr().getValue().size())
3021 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3022 <<
" gangOperands count";
3025 return emitOpError() <<
"duplicate device_type found in gang attribute";
3028 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3029 getGangOperandsDeviceTypeAttr(),
"gang")))
3034 return emitOpError() <<
"duplicate device_type found in worker attribute";
3036 return emitOpError() <<
"duplicate device_type found in "
3037 "workerNumOperandsDeviceType attribute";
3039 getWorkerNumOperandsDeviceTypeAttr(),
3045 return emitOpError() <<
"duplicate device_type found in vector attribute";
3047 return emitOpError() <<
"duplicate device_type found in "
3048 "vectorOperandsDeviceType attribute";
3050 getVectorOperandsDeviceTypeAttr(),
3055 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3056 getTileOperandsDeviceTypeAttr(),
"tile")))
3060 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3064 return emitError() <<
"only one of auto, independent, seq can be present "
3070 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3071 return attr.getValue() == mlir::acc::DeviceType::None;
3073 bool hasDefaultSeq =
3075 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3078 bool hasDefaultIndependent =
3079 getIndependentAttr()
3081 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3084 bool hasDefaultAuto =
3086 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3089 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3091 <<
"at least one of auto, independent, seq must be present";
3096 for (
auto attr : getSeqAttr()) {
3097 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3098 if (hasVector(deviceTypeAttr.getValue()) ||
3099 getVectorValue(deviceTypeAttr.getValue()) ||
3100 hasWorker(deviceTypeAttr.getValue()) ||
3101 getWorkerValue(deviceTypeAttr.getValue()) ||
3102 hasGang(deviceTypeAttr.getValue()) ||
3103 getGangValue(mlir::acc::GangArgType::Num,
3104 deviceTypeAttr.getValue()) ||
3105 getGangValue(mlir::acc::GangArgType::Dim,
3106 deviceTypeAttr.getValue()) ||
3107 getGangValue(mlir::acc::GangArgType::Static,
3108 deviceTypeAttr.getValue()))
3109 return emitError() <<
"gang, worker or vector cannot appear with seq";
3114 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
3115 "privatizations",
false)))
3119 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
3120 "firstprivate",
"firstprivatizations",
false)))
3124 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
3125 "reductions",
false)))
3128 if (getCombined().has_value() &&
3129 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3130 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3131 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3132 return emitError(
"unexpected combined constructs attribute");
3136 if (getRegion().empty())
3137 return emitError(
"expected non-empty body.");
3139 if (getUnstructured()) {
3140 if (!isContainerLike())
3142 "unstructured acc.loop must not have induction variables");
3143 }
else if (isContainerLike()) {
3147 uint64_t collapseCount = getCollapseValue().value_or(1);
3148 if (getCollapseAttr()) {
3149 for (
auto collapseEntry : getCollapseAttr()) {
3150 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3151 if (intAttr.getValue().getZExtValue() > collapseCount)
3152 collapseCount = intAttr.getValue().getZExtValue();
3160 bool foundSibling =
false;
3162 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3164 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3166 foundSibling =
true;
3171 expectedParent = op;
3174 if (collapseCount == 0)
3180 return emitError(
"found sibling loops inside container-like acc.loop");
3181 if (collapseCount != 0)
3182 return emitError(
"failed to find enough loop-like operations inside "
3183 "container-like acc.loop");
3189unsigned LoopOp::getNumDataOperands() {
3190 return getReductionOperands().size() + getPrivateOperands().size() +
3191 getFirstprivateOperands().size();
3194Value LoopOp::getDataOperand(
unsigned i) {
3195 unsigned numOptional =
3196 getLowerbound().size() + getUpperbound().size() + getStep().size();
3197 numOptional += getGangOperands().size();
3198 numOptional += getVectorOperands().size();
3199 numOptional += getWorkerNumOperands().size();
3200 numOptional += getTileOperands().size();
3201 numOptional += getCacheOperands().size();
3202 return getOperand(numOptional + i);
3205bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3207bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3211bool LoopOp::hasIndependent() {
3212 return hasIndependent(mlir::acc::DeviceType::None);
3215bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3219bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3221bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3226 return getVectorValue(mlir::acc::DeviceType::None);
3229mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3231 getVectorOperands(), deviceType);
3234bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3236bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3241 return getWorkerValue(mlir::acc::DeviceType::None);
3244mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3246 getWorkerNumOperands(), deviceType);
3249bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3251bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3256 return getTileValues(mlir::acc::DeviceType::None);
3260LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3262 getTileOperandsSegments(), deviceType);
3265std::optional<int64_t> LoopOp::getCollapseValue() {
3266 return getCollapseValue(mlir::acc::DeviceType::None);
3269std::optional<int64_t>
3270LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3271 if (!getCollapseAttr())
3272 return std::nullopt;
3273 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3275 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3276 return intAttr.getValue().getZExtValue();
3278 return std::nullopt;
3281mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3282 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3285mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3286 mlir::acc::DeviceType deviceType) {
3287 if (getGangOperands().empty())
3289 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3290 int32_t nbOperandsBefore = 0;
3291 for (
unsigned i = 0; i < *pos; ++i)
3292 nbOperandsBefore += (*getGangOperandsSegments())[i];
3295 .drop_front(nbOperandsBefore)
3296 .take_front((*getGangOperandsSegments())[*pos]);
3298 int32_t argTypeIdx = nbOperandsBefore;
3299 for (
auto value : values) {
3300 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3301 (*getGangOperandsArgType())[argTypeIdx]);
3302 if (gangArgTypeAttr.getValue() == gangArgType)
3310bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3312bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3317 return {&getRegion()};
3361 if (!regionArgs.empty()) {
3362 p << acc::LoopOp::getControlKeyword() <<
"(";
3363 llvm::interleaveComma(regionArgs, p,
3365 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3366 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3367 <<
" : " << stepType <<
") ";
3374 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3375 effectiveDeviceTypes));
3378void acc::LoopOp::addIndependent(
3380 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3381 context, getIndependentAttr(), effectiveDeviceTypes));
3386 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3387 effectiveDeviceTypes));
3390void acc::LoopOp::setCollapseForDeviceTypes(
3392 llvm::APInt value) {
3396 assert((getCollapseAttr() ==
nullptr) ==
3397 (getCollapseDeviceTypeAttr() ==
nullptr));
3398 assert(value.getBitWidth() == 64);
3400 if (getCollapseAttr()) {
3401 for (
const auto &existing :
3402 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3403 newValues.push_back(std::get<0>(existing));
3404 newDeviceTypes.push_back(std::get<1>(existing));
3408 if (effectiveDeviceTypes.empty()) {
3411 newValues.push_back(
3412 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3413 newDeviceTypes.push_back(
3414 acc::DeviceTypeAttr::get(context, DeviceType::None));
3416 for (DeviceType dt : effectiveDeviceTypes) {
3417 newValues.push_back(
3418 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3419 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3423 setCollapseAttr(ArrayAttr::get(context, newValues));
3424 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3427void acc::LoopOp::setTileForDeviceTypes(
3431 if (getTileOperandsSegments())
3432 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3434 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3435 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3436 getTileOperandsMutable(), segments));
3438 setTileOperandsSegments(segments);
3441void acc::LoopOp::addVectorOperand(
3444 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3445 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3446 newValue, getVectorOperandsMutable()));
3449void acc::LoopOp::addEmptyVector(
3451 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3452 effectiveDeviceTypes));
3455void acc::LoopOp::addWorkerNumOperand(
3458 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3459 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3460 newValue, getWorkerNumOperandsMutable()));
3463void acc::LoopOp::addEmptyWorker(
3465 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3466 effectiveDeviceTypes));
3469void acc::LoopOp::addEmptyGang(
3471 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3472 effectiveDeviceTypes));
3475bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3476 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3477 return attr.getValue() == dt;
3479 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3480 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3483 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3485 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3487 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3493bool acc::LoopOp::hasDefaultGangWorkerVector() {
3494 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3495 hasGang() || getGangValue(GangArgType::Num) ||
3496 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3500acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3501 if (hasSeq(deviceType))
3502 return LoopParMode::loop_seq;
3503 if (hasAuto(deviceType))
3504 return LoopParMode::loop_auto;
3505 if (hasIndependent(deviceType))
3506 return LoopParMode::loop_independent;
3508 return LoopParMode::loop_seq;
3510 return LoopParMode::loop_auto;
3511 assert(hasIndependent() &&
3512 "loop must have default auto, seq, or independent");
3513 return LoopParMode::loop_independent;
3516void acc::LoopOp::addGangOperands(
3521 getGangOperandsSegments())
3522 llvm::copy(*existingSegments, std::back_inserter(segments));
3524 unsigned beforeCount = segments.size();
3526 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3527 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3528 getGangOperandsMutable(), segments));
3530 setGangOperandsSegments(segments);
3537 unsigned numAdded = segments.size() - beforeCount;
3541 if (getGangOperandsArgTypeAttr())
3542 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3544 for (
auto i : llvm::index_range(0u, numAdded)) {
3545 llvm::transform(argTypes, std::back_inserter(gangTypes),
3546 [=](mlir::acc::GangArgType gangTy) {
3547 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3552 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3556void acc::LoopOp::addPrivatization(
MLIRContext *context,
3557 mlir::acc::PrivateOp op,
3558 mlir::acc::PrivateRecipeOp recipe) {
3559 getPrivateOperandsMutable().append(op.getResult());
3563 if (getPrivatizationRecipesAttr())
3564 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3567 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3568 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3571void acc::LoopOp::addFirstPrivatization(
3572 MLIRContext *context, mlir::acc::FirstprivateOp op,
3573 mlir::acc::FirstprivateRecipeOp recipe) {
3574 getFirstprivateOperandsMutable().append(op.getResult());
3578 if (getFirstprivatizationRecipesAttr())
3579 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
3582 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3583 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3586void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3587 mlir::acc::ReductionRecipeOp recipe) {
3588 getReductionOperandsMutable().append(op.getResult());
3592 if (getReductionRecipesAttr())
3593 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3596 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3597 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3604LogicalResult acc::DataOp::verify() {
3609 return emitError(
"at least one operand or the default attribute "
3610 "must appear on the data operation");
3612 for (
mlir::Value operand : getDataClauseOperands())
3613 if (isa<BlockArgument>(operand) ||
3614 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3615 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3616 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3617 operand.getDefiningOp()))
3618 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3627unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3629Value DataOp::getDataOperand(
unsigned i) {
3630 unsigned numOptional = getIfCond() ? 1 : 0;
3632 numOptional += getWaitOperands().size();
3633 return getOperand(numOptional + i);
3636bool acc::DataOp::hasAsyncOnly() {
3637 return hasAsyncOnly(mlir::acc::DeviceType::None);
3640bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3645 return getAsyncValue(mlir::acc::DeviceType::None);
3648mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3653bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3655bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3660 return getWaitValues(mlir::acc::DeviceType::None);
3664DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3666 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3667 getHasWaitDevnum(), deviceType);
3671 return getWaitDevnum(mlir::acc::DeviceType::None);
3674mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3676 getWaitOperandsSegments(), getHasWaitDevnum(),
3680void acc::DataOp::addAsyncOnly(
3682 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3683 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3686void acc::DataOp::addAsyncOperand(
3689 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3690 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3691 getAsyncOperandsMutable()));
3694void acc::DataOp::addWaitOnly(
MLIRContext *context,
3696 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3697 effectiveDeviceTypes));
3700void acc::DataOp::addWaitOperands(
3705 if (getWaitOperandsSegments())
3706 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3708 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3709 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3710 getWaitOperandsMutable(), segments));
3711 setWaitOperandsSegments(segments);
3714 if (getHasWaitDevnumAttr())
3715 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3718 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3720 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3727LogicalResult acc::ExitDataOp::verify() {
3731 if (getDataClauseOperands().empty())
3732 return emitError(
"at least one operand must be present in dataOperands on "
3733 "the exit data operation");
3737 if (getAsyncOperand() && getAsync())
3738 return emitError(
"async attribute cannot appear with asyncOperand");
3742 if (!getWaitOperands().empty() && getWait())
3743 return emitError(
"wait attribute cannot appear with waitOperands");
3745 if (getWaitDevnum() && getWaitOperands().empty())
3746 return emitError(
"wait_devnum cannot appear without waitOperands");
3751unsigned ExitDataOp::getNumDataOperands() {
3752 return getDataClauseOperands().size();
3755Value ExitDataOp::getDataOperand(
unsigned i) {
3756 unsigned numOptional = getIfCond() ? 1 : 0;
3757 numOptional += getAsyncOperand() ? 1 : 0;
3758 numOptional += getWaitDevnum() ? 1 : 0;
3759 return getOperand(getWaitOperands().size() + numOptional + i);
3764 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3767void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3769 assert(effectiveDeviceTypes.empty());
3770 assert(!getAsyncAttr());
3771 assert(!getAsyncOperand());
3773 setAsyncAttr(mlir::UnitAttr::get(context));
3776void ExitDataOp::addAsyncOperand(
3779 assert(effectiveDeviceTypes.empty());
3780 assert(!getAsyncAttr());
3781 assert(!getAsyncOperand());
3783 getAsyncOperandMutable().append(newValue);
3788 assert(effectiveDeviceTypes.empty());
3789 assert(!getWaitAttr());
3790 assert(getWaitOperands().empty());
3791 assert(!getWaitDevnum());
3793 setWaitAttr(mlir::UnitAttr::get(context));
3796void ExitDataOp::addWaitOperands(
3799 assert(effectiveDeviceTypes.empty());
3800 assert(!getWaitAttr());
3801 assert(getWaitOperands().empty());
3802 assert(!getWaitDevnum());
3807 getWaitDevnumMutable().append(newValues.front());
3808 newValues = newValues.drop_front();
3811 getWaitOperandsMutable().append(newValues);
3818LogicalResult acc::EnterDataOp::verify() {
3822 if (getDataClauseOperands().empty())
3823 return emitError(
"at least one operand must be present in dataOperands on "
3824 "the enter data operation");
3828 if (getAsyncOperand() && getAsync())
3829 return emitError(
"async attribute cannot appear with asyncOperand");
3833 if (!getWaitOperands().empty() && getWait())
3834 return emitError(
"wait attribute cannot appear with waitOperands");
3836 if (getWaitDevnum() && getWaitOperands().empty())
3837 return emitError(
"wait_devnum cannot appear without waitOperands");
3839 for (
mlir::Value operand : getDataClauseOperands())
3840 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3841 operand.getDefiningOp()))
3842 return emitError(
"expect data entry operation as defining op");
3847unsigned EnterDataOp::getNumDataOperands() {
3848 return getDataClauseOperands().size();
3851Value EnterDataOp::getDataOperand(
unsigned i) {
3852 unsigned numOptional = getIfCond() ? 1 : 0;
3853 numOptional += getAsyncOperand() ? 1 : 0;
3854 numOptional += getWaitDevnum() ? 1 : 0;
3855 return getOperand(getWaitOperands().size() + numOptional + i);
3860 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3863void EnterDataOp::addAsyncOnly(
3865 assert(effectiveDeviceTypes.empty());
3866 assert(!getAsyncAttr());
3867 assert(!getAsyncOperand());
3869 setAsyncAttr(mlir::UnitAttr::get(context));
3872void EnterDataOp::addAsyncOperand(
3875 assert(effectiveDeviceTypes.empty());
3876 assert(!getAsyncAttr());
3877 assert(!getAsyncOperand());
3879 getAsyncOperandMutable().append(newValue);
3882void EnterDataOp::addWaitOnly(
MLIRContext *context,
3884 assert(effectiveDeviceTypes.empty());
3885 assert(!getWaitAttr());
3886 assert(getWaitOperands().empty());
3887 assert(!getWaitDevnum());
3889 setWaitAttr(mlir::UnitAttr::get(context));
3892void EnterDataOp::addWaitOperands(
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getWaitAttr());
3897 assert(getWaitOperands().empty());
3898 assert(!getWaitDevnum());
3903 getWaitDevnumMutable().append(newValues.front());
3904 newValues = newValues.drop_front();
3907 getWaitOperandsMutable().append(newValues);
3914LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
3920LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
3926LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3933 if (
Value writeVal = op.getWriteOpVal()) {
3942LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
3944LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3950AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3951 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3953 return dyn_cast<AtomicReadOp>(getSecondOp());
3956AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3957 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3959 return dyn_cast<AtomicWriteOp>(getSecondOp());
3962AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3963 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3965 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3968LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3974template <
typename Op>
3977 bool requireAtLeastOneOperand =
true) {
3978 if (operands.empty() && requireAtLeastOneOperand)
3981 "at least one operand must appear on the declare operation");
3984 if (isa<BlockArgument>(operand) ||
3985 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3986 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
3987 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
3988 operand.getDefiningOp()))
3990 "expect valid declare data entry operation or acc.getdeviceptr "
3994 assert(var &&
"declare operands can only be data entry operations which "
3997 std::optional<mlir::acc::DataClause> dataClauseOptional{
3999 assert(dataClauseOptional.has_value() &&
4000 "declare operands can only be data entry operations which must have "
4002 (
void)dataClauseOptional;
4008LogicalResult acc::DeclareEnterOp::verify() {
4016LogicalResult acc::DeclareExitOp::verify() {
4027LogicalResult acc::DeclareOp::verify() {
4036 acc::DeviceType dtype) {
4037 unsigned parallelism = 0;
4038 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4039 parallelism += op.hasWorker(dtype) ? 1 : 0;
4040 parallelism += op.hasVector(dtype) ? 1 : 0;
4041 parallelism += op.hasSeq(dtype) ? 1 : 0;
4045LogicalResult acc::RoutineOp::verify() {
4046 unsigned baseParallelism =
4049 if (baseParallelism > 1)
4050 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4051 "be present at the same time";
4053 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4055 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4056 if (dtype == acc::DeviceType::None)
4060 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4061 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4062 "be present at the same time";
4069 mlir::ArrayAttr &bindIdName,
4070 mlir::ArrayAttr &bindStrName,
4071 mlir::ArrayAttr &deviceIdTypes,
4072 mlir::ArrayAttr &deviceStrTypes) {
4079 mlir::Attribute newAttr;
4080 bool isSymbolRefAttr;
4081 auto parseResult = parser.parseAttribute(newAttr);
4082 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4083 bindIdNameAttrs.push_back(symbolRefAttr);
4084 isSymbolRefAttr = true;
4085 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4086 bindStrNameAttrs.push_back(stringAttr);
4087 isSymbolRefAttr =
false;
4092 if (isSymbolRefAttr) {
4093 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4094 parser.getContext(), mlir::acc::DeviceType::None));
4096 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4097 parser.getContext(), mlir::acc::DeviceType::None));
4100 if (isSymbolRefAttr) {
4101 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4102 parser.parseRSquare())
4105 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4106 parser.parseRSquare())
4114 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4115 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4116 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4117 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4123 std::optional<mlir::ArrayAttr> bindIdName,
4124 std::optional<mlir::ArrayAttr> bindStrName,
4125 std::optional<mlir::ArrayAttr> deviceIdTypes,
4126 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4133 allBindNames.append(bindIdName->begin(), bindIdName->end());
4134 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4139 allBindNames.append(bindStrName->begin(), bindStrName->end());
4140 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4144 if (!allBindNames.empty())
4145 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4146 [&](
const auto &pair) {
4147 p << std::get<0>(pair);
4153 mlir::ArrayAttr &gang,
4154 mlir::ArrayAttr &gangDim,
4155 mlir::ArrayAttr &gangDimDeviceTypes) {
4158 gangDimDeviceTypeAttrs;
4159 bool needCommaBeforeOperands =
false;
4163 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4164 parser.
getContext(), mlir::acc::DeviceType::None));
4165 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4172 if (parser.parseAttribute(gangAttrs.emplace_back()))
4179 needCommaBeforeOperands =
true;
4182 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4186 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4187 parser.parseColon() ||
4188 parser.parseAttribute(gangDimAttrs.emplace_back()))
4190 if (succeeded(parser.parseOptionalLSquare())) {
4191 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4192 parser.parseRSquare())
4195 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4196 parser.getContext(), mlir::acc::DeviceType::None));
4202 if (
failed(parser.parseRParen()))
4205 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4206 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4207 gangDimDeviceTypes =
4208 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4214 std::optional<mlir::ArrayAttr> gang,
4215 std::optional<mlir::ArrayAttr> gangDim,
4216 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4219 gang->size() == 1) {
4220 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4221 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4233 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4234 [&](
const auto &pair) {
4235 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4236 p << std::get<0>(pair);
4244 mlir::ArrayAttr &deviceTypes) {
4248 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4249 parser.
getContext(), mlir::acc::DeviceType::None));
4250 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4257 if (parser.parseAttribute(attributes.emplace_back()))
4265 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4271 std::optional<mlir::ArrayAttr> deviceTypes) {
4274 auto deviceTypeAttr =
4275 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4276 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4285 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4291bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4293bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4297bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4299bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4303bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4305bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4309std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4310RoutineOp::getBindNameValue() {
4311 return getBindNameValue(mlir::acc::DeviceType::None);
4314std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4315RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4318 return std::nullopt;
4321 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4322 auto attr = (*getBindIdName())[*pos];
4323 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4324 assert(symbolRefAttr &&
"expected SymbolRef");
4325 return symbolRefAttr;
4328 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4329 auto attr = (*getBindStrName())[*pos];
4330 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4331 assert(stringAttr &&
"expected String");
4335 return std::nullopt;
4338bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4340bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4344std::optional<int64_t> RoutineOp::getGangDimValue() {
4345 return getGangDimValue(mlir::acc::DeviceType::None);
4348std::optional<int64_t>
4349RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4351 return std::nullopt;
4352 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4353 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4354 return intAttr.getInt();
4356 return std::nullopt;
4363LogicalResult acc::InitOp::verify() {
4367 return emitOpError(
"cannot be nested in a compute operation");
4371void acc::InitOp::addDeviceType(
MLIRContext *context,
4372 mlir::acc::DeviceType deviceType) {
4374 if (getDeviceTypesAttr())
4375 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4377 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4378 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4385LogicalResult acc::ShutdownOp::verify() {
4389 return emitOpError(
"cannot be nested in a compute operation");
4393void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4394 mlir::acc::DeviceType deviceType) {
4396 if (getDeviceTypesAttr())
4397 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4399 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4400 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4407LogicalResult acc::SetOp::verify() {
4411 return emitOpError(
"cannot be nested in a compute operation");
4412 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4413 return emitOpError(
"at least one default_async, device_num, or device_type "
4414 "operand must appear");
4422LogicalResult acc::UpdateOp::verify() {
4424 if (getDataClauseOperands().empty())
4425 return emitError(
"at least one value must be present in dataOperands");
4428 getAsyncOperandsDeviceTypeAttr(),
4433 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4434 getWaitOperandsDeviceTypeAttr(),
"wait")))
4440 for (
mlir::Value operand : getDataClauseOperands())
4441 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4442 operand.getDefiningOp()))
4443 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4449unsigned UpdateOp::getNumDataOperands() {
4450 return getDataClauseOperands().size();
4453Value UpdateOp::getDataOperand(
unsigned i) {
4455 numOptional += getIfCond() ? 1 : 0;
4456 return getOperand(getWaitOperands().size() + numOptional + i);
4461 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4464bool UpdateOp::hasAsyncOnly() {
4465 return hasAsyncOnly(mlir::acc::DeviceType::None);
4468bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4473 return getAsyncValue(mlir::acc::DeviceType::None);
4476mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4486bool UpdateOp::hasWaitOnly() {
4487 return hasWaitOnly(mlir::acc::DeviceType::None);
4490bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4495 return getWaitValues(mlir::acc::DeviceType::None);
4499UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4501 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4502 getHasWaitDevnum(), deviceType);
4506 return getWaitDevnum(mlir::acc::DeviceType::None);
4509mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4511 getWaitOperandsSegments(), getHasWaitDevnum(),
4517 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4518 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4521void UpdateOp::addAsyncOperand(
4524 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4525 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4526 getAsyncOperandsMutable()));
4531 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4532 effectiveDeviceTypes));
4535void UpdateOp::addWaitOperands(
4540 if (getWaitOperandsSegments())
4541 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4543 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4544 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4545 getWaitOperandsMutable(), segments));
4546 setWaitOperandsSegments(segments);
4549 if (getHasWaitDevnumAttr())
4550 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4553 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4555 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4562LogicalResult acc::WaitOp::verify() {
4565 if (getAsyncOperand() && getAsync())
4566 return emitError(
"async attribute cannot appear with asyncOperand");
4568 if (getWaitDevnum() && getWaitOperands().empty())
4569 return emitError(
"wait_devnum cannot appear without waitOperands");
4574#define GET_OP_CLASSES
4575#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4577#define GET_ATTRDEF_CLASSES
4578#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4580#define GET_TYPEDEF_CLASSES
4581#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4592 .Case<ACC_DATA_ENTRY_OPS>(
4593 [&](
auto entry) {
return entry.getVarPtr(); })
4594 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4595 [&](
auto exit) {
return exit.getVarPtr(); })
4613 [&](
auto entry) {
return entry.getVarType(); })
4614 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4615 [&](
auto exit) {
return exit.getVarType(); })
4625 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4626 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4636 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4645 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4655 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4657 dataClause.getBounds().begin(), dataClause.getBounds().end());
4669 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4671 dataClause.getAsyncOperands().begin(),
4672 dataClause.getAsyncOperands().end());
4683 return dataClause.getAsyncOperandsDeviceTypeAttr();
4691 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4698 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4705std::optional<mlir::acc::DataClause>
4710 .Case<ACC_DATA_ENTRY_OPS>(
4711 [&](
auto entry) {
return entry.getDataClause(); })
4719 [&](
auto entry) {
return entry.getImplicit(); })
4728 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4730 return dataOperands;
4738 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4740 return dataOperands;
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
false
Parses a map_entries map type from a string format back into its numeric value.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.