23#include "llvm/ADT/SmallSet.h"
24#include "llvm/ADT/TypeSwitch.h"
25#include "llvm/Support/LogicalResult.h"
31#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
32#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
39static bool isScalarLikeType(
Type type) {
47 if (!varName.empty()) {
48 auto varNameAttr = acc::VarNameAttr::get(builder.
getContext(), varName);
54struct MemRefPointerLikeModel
55 :
public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
57 return cast<T>(pointer).getElementType();
60 mlir::acc::VariableTypeCategory
63 if (
auto mappableTy = dyn_cast<MappableType>(varType)) {
64 return mappableTy.getTypeCategory(varPtr);
66 auto memrefTy = cast<T>(pointer);
67 if (!memrefTy.hasRank()) {
70 return mlir::acc::VariableTypeCategory::uncategorized;
73 if (memrefTy.getRank() == 0) {
74 if (isScalarLikeType(memrefTy.getElementType())) {
75 return mlir::acc::VariableTypeCategory::scalar;
79 return mlir::acc::VariableTypeCategory::uncategorized;
83 assert(memrefTy.getRank() > 0 &&
"rank expected to be positive");
84 return mlir::acc::VariableTypeCategory::array;
87 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
88 StringRef varName, Type varType, Value originalVar,
89 bool &needsFree)
const {
90 auto memrefTy = cast<MemRefType>(pointer);
94 if (memrefTy.hasStaticShape()) {
96 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
97 attachVarNameAttr(allocaOp, builder, varName);
98 return allocaOp.getResult();
103 if (originalVar && originalVar.
getType() == memrefTy &&
104 memrefTy.hasRank()) {
105 SmallVector<Value> dynamicSizes;
106 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
107 if (memrefTy.isDynamicDim(i)) {
111 memref::DimOp::create(builder, loc, originalVar, indexValue);
112 dynamicSizes.push_back(dimSize);
119 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
120 attachVarNameAttr(allocOp, builder, varName);
121 return allocOp.getResult();
128 bool genFree(Type pointer, OpBuilder &builder, Location loc,
130 Type varType)
const {
133 Value valueToInspect = allocRes ? allocRes : memrefValue;
136 Value currentValue = valueToInspect;
137 Operation *originalAlloc =
nullptr;
141 while (currentValue) {
144 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
145 originalAlloc = definingOp;
150 if (
auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
151 currentValue = castOp.getSource();
156 if (
auto reinterpretCastOp =
157 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
158 currentValue = reinterpretCastOp.getSource();
170 if (isa<memref::AllocaOp>(originalAlloc)) {
174 if (isa<memref::AllocOp>(originalAlloc)) {
176 memref::DeallocOp::create(builder, loc, memrefValue);
185 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
189 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
190 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
196 if (destMemref && srcMemref &&
197 destMemref.getType().getElementType() ==
198 srcMemref.getType().getElementType() &&
199 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
200 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
208struct LLVMPointerPointerLikeModel
209 :
public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
210 LLVM::LLVMPointerType> {
214struct MemrefAddressOfGlobalModel
215 :
public AddressOfGlobalOpInterface::ExternalModel<
216 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
217 SymbolRefAttr getSymbol(Operation *op)
const {
218 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
219 return getGlobalOp.getNameAttr();
223struct MemrefGlobalVariableModel
224 :
public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
226 bool isConstant(Operation *op)
const {
227 auto globalOp = cast<memref::GlobalOp>(op);
228 return globalOp.getConstant();
236mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
237 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
240 if (existingDeviceTypes)
241 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
243 if (newDeviceTypes.empty())
244 deviceTypes.push_back(
245 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
247 for (DeviceType dt : newDeviceTypes)
248 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
250 return mlir::ArrayAttr::get(context, deviceTypes);
259mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
260 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
265 if (existingDeviceTypes)
266 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
268 if (newDeviceTypes.empty()) {
269 argCollection.
append(arguments);
270 segments.push_back(arguments.size());
271 deviceTypes.push_back(
272 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
275 for (DeviceType dt : newDeviceTypes) {
276 argCollection.
append(arguments);
277 segments.push_back(arguments.size());
278 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
281 return mlir::ArrayAttr::get(context, deviceTypes);
285mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
286 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
290 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
291 newDeviceTypes, arguments,
292 argCollection, segments);
300void OpenACCDialect::initialize() {
303#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
306#define GET_ATTRDEF_LIST
307#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
310#define GET_TYPEDEF_LIST
311#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
317 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
319 UnrankedMemRefType::attachInterface<
320 MemRefPointerLikeModel<UnrankedMemRefType>>(*
getContext());
321 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
325 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
327 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*
getContext());
335 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
339 mlir::acc::DeviceType deviceType) {
343 for (
auto attr : *arrayAttr) {
344 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
345 if (deviceTypeAttr.getValue() == deviceType)
353 std::optional<mlir::ArrayAttr> deviceTypes) {
358 llvm::interleaveComma(*deviceTypes, p,
364 mlir::acc::DeviceType deviceType) {
365 unsigned segmentIdx = 0;
366 for (
auto attr : segments) {
367 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
368 if (deviceTypeAttr.getValue() == deviceType)
369 return std::make_optional(segmentIdx);
379 mlir::acc::DeviceType deviceType) {
381 return range.take_front(0);
382 if (
auto pos =
findSegment(*arrayAttr, deviceType)) {
383 int32_t nbOperandsBefore = 0;
384 for (
unsigned i = 0; i < *pos; ++i)
385 nbOperandsBefore += (*segments)[i];
386 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
388 return range.take_front(0);
395 std::optional<mlir::ArrayAttr> hasWaitDevnum,
396 mlir::acc::DeviceType deviceType) {
399 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType))
400 if (hasWaitDevnum->getValue()[*pos])
411 std::optional<mlir::ArrayAttr> hasWaitDevnum,
412 mlir::acc::DeviceType deviceType) {
417 if (
auto pos =
findSegment(*deviceTypeAttr, deviceType)) {
418 if (hasWaitDevnum && *hasWaitDevnum) {
419 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
420 if (boolAttr.getValue())
421 return range.drop_front(1);
427template <
typename Op>
429 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
431 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
436 op.hasAsyncOnly(dtype))
438 "asyncOnly attribute cannot appear with asyncOperand");
443 op.hasWaitOnly(dtype))
444 return op.
emitError(
"wait attribute cannot appear with waitOperands");
449template <
typename Op>
452 return op.
emitError(
"must have var operand");
455 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
456 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
457 return op.
emitError(
"var must be mappable or pointer-like");
460 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
461 op.getVarType() == op.getVar().getType())
462 return op.
emitError(
"varType must capture the element type of var");
467template <
typename Op>
469 if (op.getVar().getType() != op.getAccVar().getType())
470 return op.
emitError(
"input and output types must match");
475template <
typename Op>
477 if (op.getModifiers() != acc::DataClauseModifier::none)
478 return op.
emitError(
"no data clause modifiers are allowed");
482template <
typename Op>
485 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
487 "invalid data clause modifiers: " +
488 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
510 if (mlir::isa<mlir::acc::PointerLikeType>(var.
getType()))
531 if (failed(parser.
parseType(accVarType)))
541 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.
getType()))
553 mlir::TypeAttr &varTypeAttr) {
554 if (failed(parser.
parseType(varPtrType)))
565 varTypeAttr = mlir::TypeAttr::get(varType);
570 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
571 varTypeAttr = mlir::TypeAttr::get(
572 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).
getElementType());
574 varTypeAttr = mlir::TypeAttr::get(varPtrType);
581 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
589 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
590 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
592 if (typeToCheckAgainst != varType) {
602LogicalResult acc::DataBoundsOp::verify() {
603 auto extent = getExtent();
604 auto upperbound = getUpperbound();
605 if (!extent && !upperbound)
606 return emitError(
"expected extent or upperbound.");
613LogicalResult acc::PrivateOp::verify() {
616 "data clause associated with private operation must match its intent");
627LogicalResult acc::FirstprivateOp::verify() {
629 return emitError(
"data clause associated with firstprivate operation must "
641LogicalResult acc::FirstprivateMapInitialOp::verify() {
643 return emitError(
"data clause associated with firstprivate operation must "
655LogicalResult acc::ReductionOp::verify() {
657 return emitError(
"data clause associated with reduction operation must "
669LogicalResult acc::DevicePtrOp::verify() {
671 return emitError(
"data clause associated with deviceptr operation must "
685LogicalResult acc::PresentOp::verify() {
688 "data clause associated with present operation must match its intent");
701LogicalResult acc::CopyinOp::verify() {
703 if (!getImplicit() &&
getDataClause() != acc::DataClause::acc_copyin &&
708 "data clause associated with copyin operation must match its intent"
709 " or specify original clause this operation was decomposed from");
715 acc::DataClauseModifier::always |
716 acc::DataClauseModifier::capture)))
721bool acc::CopyinOp::isCopyinReadonly() {
722 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
723 acc::bitEnumContainsAny(getModifiers(),
724 acc::DataClauseModifier::readonly);
730LogicalResult acc::CreateOp::verify() {
737 "data clause associated with create operation must match its intent"
738 " or specify original clause this operation was decomposed from");
746 acc::DataClauseModifier::always |
747 acc::DataClauseModifier::capture)))
752bool acc::CreateOp::isCreateZero() {
754 return getDataClause() == acc::DataClause::acc_create_zero ||
756 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
762LogicalResult acc::NoCreateOp::verify() {
764 return emitError(
"data clause associated with no_create operation must "
778LogicalResult acc::AttachOp::verify() {
781 "data clause associated with attach operation must match its intent");
795LogicalResult acc::DeclareDeviceResidentOp::verify() {
796 if (
getDataClause() != acc::DataClause::acc_declare_device_resident)
797 return emitError(
"data clause associated with device_resident operation "
798 "must match its intent");
812LogicalResult acc::DeclareLinkOp::verify() {
815 "data clause associated with link operation must match its intent");
828LogicalResult acc::CopyoutOp::verify() {
835 "data clause associated with copyout operation must match its intent"
836 " or specify original clause this operation was decomposed from");
838 return emitError(
"must have both host and device pointers");
844 acc::DataClauseModifier::always |
845 acc::DataClauseModifier::capture)))
850bool acc::CopyoutOp::isCopyoutZero() {
851 return getDataClause() == acc::DataClause::acc_copyout_zero ||
852 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
858LogicalResult acc::DeleteOp::verify() {
867 getDataClause() != acc::DataClause::acc_declare_device_resident &&
870 "data clause associated with delete operation must match its intent"
871 " or specify original clause this operation was decomposed from");
873 return emitError(
"must have device pointer");
877 acc::DataClauseModifier::readonly |
878 acc::DataClauseModifier::always |
879 acc::DataClauseModifier::capture)))
887LogicalResult acc::DetachOp::verify() {
892 "data clause associated with detach operation must match its intent"
893 " or specify original clause this operation was decomposed from");
895 return emitError(
"must have device pointer");
904LogicalResult acc::UpdateHostOp::verify() {
909 "data clause associated with host operation must match its intent"
910 " or specify original clause this operation was decomposed from");
912 return emitError(
"must have both host and device pointers");
925LogicalResult acc::UpdateDeviceOp::verify() {
929 "data clause associated with device operation must match its intent"
930 " or specify original clause this operation was decomposed from");
943LogicalResult acc::UseDeviceOp::verify() {
947 "data clause associated with use_device operation must match its intent"
948 " or specify original clause this operation was decomposed from");
961LogicalResult acc::CacheOp::verify() {
966 "data clause associated with cache operation must match its intent"
967 " or specify original clause this operation was decomposed from");
977bool acc::CacheOp::isCacheReadonly() {
978 return getDataClause() == acc::DataClause::acc_cache_readonly ||
979 acc::bitEnumContainsAny(getModifiers(),
980 acc::DataClauseModifier::readonly);
983template <
typename StructureOp>
985 unsigned nRegions = 1) {
988 for (
unsigned i = 0; i < nRegions; ++i)
991 for (
Region *region : regions)
999 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1006template <
typename OpTy>
1008 using OpRewritePattern<OpTy>::OpRewritePattern;
1010 LogicalResult matchAndRewrite(OpTy op,
1011 PatternRewriter &rewriter)
const override {
1013 Value ifCond = op.getIfCond();
1017 IntegerAttr constAttr;
1020 if (constAttr.getInt())
1021 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1033 assert(region.
hasOneBlock() &&
"expected single-block region");
1045template <
typename OpTy>
1046struct RemoveConstantIfConditionWithRegion :
public OpRewritePattern<OpTy> {
1047 using OpRewritePattern<OpTy>::OpRewritePattern;
1049 LogicalResult matchAndRewrite(OpTy op,
1050 PatternRewriter &rewriter)
const override {
1052 Value ifCond = op.getIfCond();
1056 IntegerAttr constAttr;
1059 if (constAttr.getInt())
1060 rewriter.
modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1070struct RemoveEmptyKernelEnvironment
1072 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1074 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1075 PatternRewriter &rewriter)
const override {
1076 assert(op->getNumRegions() == 1 &&
"expected op to have one region");
1087 if (
auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1088 for (
auto attr : deviceTypeAttr) {
1089 if (
auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1090 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1097 if (
auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1098 for (
auto attr : hasDevnumAttr) {
1099 if (
auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1100 if (boolAttr.getValue())
1107 if (
auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1108 if (segmentsAttr.size() > 1)
1114 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1141 for (
Value bound : bounds) {
1142 argTypes.push_back(bound.getType());
1143 argLocs.push_back(loc);
1150 Value privatizedValue;
1156 if (isa<MappableType>(varType)) {
1157 auto mappableTy = cast<MappableType>(varType);
1158 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1159 privatizedValue = mappableTy.generatePrivateInit(
1160 builder, loc, typedVar, varName, bounds, {}, needsFree);
1161 if (!privatizedValue)
1164 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1165 auto pointerLikeTy = cast<PointerLikeType>(varType);
1167 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1168 blockArgVar, needsFree);
1169 if (!privatizedValue)
1174 acc::YieldOp::create(builder, loc, privatizedValue);
1189 for (
Value bound : bounds) {
1190 copyArgTypes.push_back(bound.getType());
1191 copyArgLocs.push_back(loc);
1198 bool isMappable = isa<MappableType>(varType);
1199 bool isPointerLike = isa<PointerLikeType>(varType);
1202 if (isMappable && !isPointerLike)
1206 if (isPointerLike) {
1207 auto pointerLikeTy = cast<PointerLikeType>(varType);
1212 if (!pointerLikeTy.genCopy(
1219 acc::TerminatorOp::create(builder, loc);
1233 for (
Value bound : bounds) {
1234 destroyArgTypes.push_back(bound.getType());
1235 destroyArgLocs.push_back(loc);
1239 destroyBlock->
addArguments(destroyArgTypes, destroyArgLocs);
1243 cast<TypedValue<PointerLikeType>>(destroyBlock->
getArgument(1));
1244 if (isa<MappableType>(varType)) {
1245 auto mappableTy = cast<MappableType>(varType);
1246 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1249 assert(isa<PointerLikeType>(varType) &&
"Expected PointerLikeType");
1250 auto pointerLikeTy = cast<PointerLikeType>(varType);
1251 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1255 acc::TerminatorOp::create(builder, loc);
1266 Operation *op,
Region ®ion, StringRef regionType, StringRef regionName,
1268 if (optional && region.
empty())
1272 return op->
emitOpError() <<
"expects non-empty " << regionName <<
" region";
1276 return op->
emitOpError() <<
"expects " << regionName
1279 << regionType <<
" type";
1282 for (YieldOp yieldOp : region.
getOps<acc::YieldOp>()) {
1283 if (yieldOp.getOperands().size() != 1 ||
1284 yieldOp.getOperands().getTypes()[0] != type)
1285 return op->
emitOpError() <<
"expects " << regionName
1287 "yield a value of the "
1288 << regionType <<
" type";
1294LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1296 "privatization",
"init",
getType(),
1300 *
this, getDestroyRegion(),
"privatization",
"destroy",
getType(),
1306std::optional<PrivateRecipeOp>
1308 StringRef recipeName,
Type varType,
1311 bool isMappable = isa<MappableType>(varType);
1312 bool isPointerLike = isa<PointerLikeType>(varType);
1315 if (!isMappable && !isPointerLike)
1316 return std::nullopt;
1321 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1324 bool needsFree =
false;
1325 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1326 varName, bounds, needsFree))) {
1328 return std::nullopt;
1335 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1336 Value allocRes = yieldOp.getOperand(0);
1338 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1339 varType, allocRes, bounds))) {
1341 return std::nullopt;
1352LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1354 "privatization",
"init",
getType(),
1358 if (getCopyRegion().empty())
1359 return emitOpError() <<
"expects non-empty copy region";
1364 return emitOpError() <<
"expects copy region with two arguments of the "
1365 "privatization type";
1367 if (getDestroyRegion().empty())
1371 "privatization",
"destroy",
1378std::optional<FirstprivateRecipeOp>
1380 StringRef recipeName,
Type varType,
1383 bool isMappable = isa<MappableType>(varType);
1384 bool isPointerLike = isa<PointerLikeType>(varType);
1387 if (!isMappable && !isPointerLike)
1388 return std::nullopt;
1393 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1396 bool needsFree =
false;
1397 if (
failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1398 varName, bounds, needsFree))) {
1400 return std::nullopt;
1404 if (
failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1407 return std::nullopt;
1414 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1415 Value allocRes = yieldOp.getOperand(0);
1417 if (
failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1418 varType, allocRes, bounds))) {
1420 return std::nullopt;
1431LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1437 if (getCombinerRegion().empty())
1438 return emitOpError() <<
"expects non-empty combiner region";
1440 Block &reductionBlock = getCombinerRegion().
front();
1444 return emitOpError() <<
"expects combiner region with the first two "
1445 <<
"arguments of the reduction type";
1447 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1448 if (yieldOp.getOperands().size() != 1 ||
1449 yieldOp.getOperands().getTypes()[0] !=
getType())
1450 return emitOpError() <<
"expects combiner region to yield a value "
1451 "of the reduction type";
1467 if (parser.parseAttribute(attributes.emplace_back()) ||
1468 parser.parseArrow() ||
1469 parser.parseOperand(operands.emplace_back()) ||
1470 parser.parseColonType(types.emplace_back()))
1477 symbols = ArrayAttr::get(parser.
getContext(), arrayAttr);
1484 std::optional<mlir::ArrayAttr> attributes) {
1485 llvm::interleaveComma(llvm::zip(*attributes, operands), p, [&](
auto it) {
1486 p << std::get<0>(it) <<
" -> " << std::get<1>(it) <<
" : "
1487 << std::get<1>(it).getType();
1496template <
typename Op>
1500 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1501 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1502 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1503 operand.getDefiningOp()))
1505 "expect data entry/exit operation or acc.getdeviceptr "
1510template <
typename Op>
1514 llvm::StringRef symbolName,
bool checkOperandType =
true) {
1515 if (!operands.empty()) {
1516 if (!attributes || attributes->size() != operands.size())
1518 <<
"expected as many " << symbolName <<
" symbol reference as "
1519 << operandName <<
" operands";
1523 <<
"unexpected " << symbolName <<
" symbol reference";
1528 for (
auto args : llvm::zip(operands, *attributes)) {
1531 if (!set.insert(operand).second)
1533 << operandName <<
" operand appears more than once";
1536 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1540 <<
"expected symbol reference " << symbolRef <<
" to point to a "
1541 << operandName <<
" declaration";
1543 if (checkOperandType && decl.getType() && decl.getType() != varType)
1544 return op->
emitOpError() <<
"expected " << operandName <<
" (" << varType
1545 <<
") to be the same type as " << operandName
1546 <<
" declaration (" << decl.getType() <<
")";
1552unsigned ParallelOp::getNumDataOperands() {
1553 return getReductionOperands().size() + getPrivateOperands().size() +
1554 getFirstprivateOperands().size() + getDataClauseOperands().size();
1557Value ParallelOp::getDataOperand(
unsigned i) {
1559 numOptional += getNumGangs().size();
1560 numOptional += getNumWorkers().size();
1561 numOptional += getVectorLength().size();
1562 numOptional += getIfCond() ? 1 : 0;
1563 numOptional += getSelfCond() ? 1 : 0;
1564 return getOperand(getWaitOperands().size() + numOptional + i);
1567template <
typename Op>
1570 llvm::StringRef keyword) {
1571 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1572 return op.
emitOpError() << keyword <<
" operands count must match "
1573 << keyword <<
" device_type count";
1577template <
typename Op>
1580 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1581 std::size_t numOperandsInSegments = 0;
1582 std::size_t nbOfSegments = 0;
1585 for (
auto segCount : segments.
asArrayRef()) {
1586 if (maxInSegment != 0 && segCount > maxInSegment)
1587 return op.
emitOpError() << keyword <<
" expects a maximum of "
1588 << maxInSegment <<
" values per segment";
1589 numOperandsInSegments += segCount;
1594 if ((numOperandsInSegments != operands.size()) ||
1595 (!deviceTypes && !operands.empty()))
1597 << keyword <<
" operand count does not match count in segments";
1598 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1600 << keyword <<
" segment count does not match device_type count";
1604LogicalResult acc::ParallelOp::verify() {
1606 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
1607 "privatizations",
false)))
1610 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
1611 "firstprivate",
"firstprivatizations",
false)))
1614 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
1615 "reductions",
false)))
1619 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
1620 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
1624 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1625 getWaitOperandsDeviceTypeAttr(),
"wait")))
1629 getNumWorkersDeviceTypeAttr(),
1634 getVectorLengthDeviceTypeAttr(),
1639 getAsyncOperandsDeviceTypeAttr(),
1652 mlir::acc::DeviceType deviceType) {
1655 if (
auto pos =
findSegment(*arrayAttr, deviceType))
1660bool acc::ParallelOp::hasAsyncOnly() {
1661 return hasAsyncOnly(mlir::acc::DeviceType::None);
1664bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1669 return getAsyncValue(mlir::acc::DeviceType::None);
1672mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1677mlir::Value acc::ParallelOp::getNumWorkersValue() {
1678 return getNumWorkersValue(mlir::acc::DeviceType::None);
1682acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1687mlir::Value acc::ParallelOp::getVectorLengthValue() {
1688 return getVectorLengthValue(mlir::acc::DeviceType::None);
1692acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1694 getVectorLength(), deviceType);
1698 return getNumGangsValues(mlir::acc::DeviceType::None);
1702ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1704 getNumGangsSegments(), deviceType);
1707bool acc::ParallelOp::hasWaitOnly() {
1708 return hasWaitOnly(mlir::acc::DeviceType::None);
1711bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1716 return getWaitValues(mlir::acc::DeviceType::None);
1720ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1722 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1723 getHasWaitDevnum(), deviceType);
1727 return getWaitDevnum(mlir::acc::DeviceType::None);
1730mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1732 getWaitOperandsSegments(), getHasWaitDevnum(),
1748 odsBuilder, odsState, asyncOperands,
nullptr,
1749 nullptr, waitOperands,
nullptr,
1751 nullptr, numGangs,
nullptr,
1752 nullptr, numWorkers,
1753 nullptr, vectorLength,
1754 nullptr, ifCond, selfCond,
1755 nullptr, reductionOperands,
nullptr,
1756 gangPrivateOperands,
nullptr, gangFirstPrivateOperands,
1757 nullptr, dataClauseOperands,
1761void acc::ParallelOp::addNumWorkersOperand(
1764 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1765 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1766 getNumWorkersMutable()));
1768void acc::ParallelOp::addVectorLengthOperand(
1771 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1772 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1773 getVectorLengthMutable()));
1776void acc::ParallelOp::addAsyncOnly(
1778 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1779 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1782void acc::ParallelOp::addAsyncOperand(
1785 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1786 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1787 getAsyncOperandsMutable()));
1790void acc::ParallelOp::addNumGangsOperands(
1794 if (getNumGangsSegments())
1795 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1797 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1798 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1799 getNumGangsMutable(), segments));
1801 setNumGangsSegments(segments);
1803void acc::ParallelOp::addWaitOnly(
1805 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1806 effectiveDeviceTypes));
1808void acc::ParallelOp::addWaitOperands(
1813 if (getWaitOperandsSegments())
1814 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1816 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1817 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1818 getWaitOperandsMutable(), segments));
1819 setWaitOperandsSegments(segments);
1822 if (getHasWaitDevnumAttr())
1823 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1826 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
1828 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1831void acc::ParallelOp::addPrivatization(
MLIRContext *context,
1832 mlir::acc::PrivateOp op,
1833 mlir::acc::PrivateRecipeOp recipe) {
1834 getPrivateOperandsMutable().append(op.getResult());
1838 if (getPrivatizationRecipesAttr())
1839 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
1842 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1843 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1846void acc::ParallelOp::addFirstPrivatization(
1847 MLIRContext *context, mlir::acc::FirstprivateOp op,
1848 mlir::acc::FirstprivateRecipeOp recipe) {
1849 getFirstprivateOperandsMutable().append(op.getResult());
1853 if (getFirstprivatizationRecipesAttr())
1854 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
1857 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1858 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1861void acc::ParallelOp::addReduction(
MLIRContext *context,
1862 mlir::acc::ReductionOp op,
1863 mlir::acc::ReductionRecipeOp recipe) {
1864 getReductionOperandsMutable().append(op.getResult());
1868 if (getReductionRecipesAttr())
1869 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
1872 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
1873 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
1888 int32_t crtOperandsSize = operands.size();
1891 if (parser.parseOperand(operands.emplace_back()) ||
1892 parser.parseColonType(types.emplace_back()))
1897 seg.push_back(operands.size() - crtOperandsSize);
1907 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1908 parser.
getContext(), mlir::acc::DeviceType::None));
1914 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
1921 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
1922 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
1923 p <<
" [" << attr <<
"]";
1928 std::optional<mlir::ArrayAttr> deviceTypes,
1929 std::optional<mlir::DenseI32ArrayAttr> segments) {
1931 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1933 llvm::interleaveComma(
1934 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1935 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
1955 int32_t crtOperandsSize = operands.size();
1959 if (parser.parseOperand(operands.emplace_back()) ||
1960 parser.parseColonType(types.emplace_back()))
1966 seg.push_back(operands.size() - crtOperandsSize);
1976 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
1977 parser.
getContext(), mlir::acc::DeviceType::None));
1983 deviceTypes = ArrayAttr::get(parser.
getContext(), arrayAttr);
1992 std::optional<mlir::DenseI32ArrayAttr> segments) {
1994 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
1996 llvm::interleaveComma(
1997 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
1998 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2011 mlir::ArrayAttr &keywordOnly) {
2015 bool needCommaBeforeOperands =
false;
2019 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2020 parser.
getContext(), mlir::acc::DeviceType::None));
2021 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2028 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2035 needCommaBeforeOperands =
true;
2038 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2045 int32_t crtOperandsSize = operands.size();
2057 if (parser.parseOperand(operands.emplace_back()) ||
2058 parser.parseColonType(types.emplace_back()))
2064 seg.push_back(operands.size() - crtOperandsSize);
2074 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2075 parser.
getContext(), mlir::acc::DeviceType::None));
2082 deviceTypes = ArrayAttr::get(parser.
getContext(), deviceTypeAttrs);
2083 keywordOnly = ArrayAttr::get(parser.
getContext(), keywordAttrs);
2085 hasDevNum = ArrayAttr::get(parser.
getContext(), devnum);
2093 if (attrs->size() != 1)
2095 if (
auto deviceTypeAttr =
2096 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2097 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2103 std::optional<mlir::ArrayAttr> deviceTypes,
2104 std::optional<mlir::DenseI32ArrayAttr> segments,
2105 std::optional<mlir::ArrayAttr> hasDevNum,
2106 std::optional<mlir::ArrayAttr> keywordOnly) {
2119 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2121 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2122 if (boolAttr && boolAttr.getValue())
2124 llvm::interleaveComma(
2125 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2126 p << operands[opIdx] <<
" : " << operands[opIdx].getType();
2143 if (parser.parseOperand(operands.emplace_back()) ||
2144 parser.parseColonType(types.emplace_back()))
2146 if (succeeded(parser.parseOptionalLSquare())) {
2147 if (parser.parseAttribute(attributes.emplace_back()) ||
2148 parser.parseRSquare())
2151 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2152 parser.getContext(), mlir::acc::DeviceType::None));
2159 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2166 std::optional<mlir::ArrayAttr> deviceTypes) {
2169 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](
auto it) {
2170 p << std::get<1>(it) <<
" : " << std::get<1>(it).getType();
2179 mlir::ArrayAttr &keywordOnlyDeviceType) {
2182 bool needCommaBeforeOperands =
false;
2186 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2187 parser.
getContext(), mlir::acc::DeviceType::None));
2188 keywordOnlyDeviceType =
2189 ArrayAttr::get(parser.
getContext(), keywordOnlyDeviceTypeAttributes);
2197 if (parser.parseAttribute(
2198 keywordOnlyDeviceTypeAttributes.emplace_back()))
2205 needCommaBeforeOperands =
true;
2208 if (needCommaBeforeOperands && failed(parser.
parseComma()))
2213 if (parser.parseOperand(operands.emplace_back()) ||
2214 parser.parseColonType(types.emplace_back()))
2216 if (succeeded(parser.parseOptionalLSquare())) {
2217 if (parser.parseAttribute(attributes.emplace_back()) ||
2218 parser.parseRSquare())
2221 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2222 parser.getContext(), mlir::acc::DeviceType::None));
2228 if (
failed(parser.parseRParen()))
2233 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2240 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2242 if (operands.begin() == operands.end() &&
2258 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2259 mlir::Type &operandType, mlir::UnitAttr &attr) {
2262 attr = mlir::UnitAttr::get(parser.
getContext());
2272 if (failed(parser.
parseType(operandType)))
2282 std::optional<mlir::Value> operand,
2284 mlir::UnitAttr attr) {
2301 attr = mlir::UnitAttr::get(parser.
getContext());
2306 if (parser.parseOperand(operands.emplace_back()))
2314 if (parser.parseType(types.emplace_back()))
2329 mlir::UnitAttr attr) {
2334 llvm::interleaveComma(operands, p, [&](
auto it) { p << it; });
2336 llvm::interleaveComma(types, p, [&](
auto it) { p << it; });
2342 mlir::acc::CombinedConstructsTypeAttr &attr) {
2344 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2345 parser.
getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2347 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2348 parser.
getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2350 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2351 parser.
getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2354 "expected compute construct name");
2362 mlir::acc::CombinedConstructsTypeAttr attr) {
2364 switch (attr.getValue()) {
2365 case mlir::acc::CombinedConstructsType::KernelsLoop:
2368 case mlir::acc::CombinedConstructsType::ParallelLoop:
2371 case mlir::acc::CombinedConstructsType::SerialLoop:
2382unsigned SerialOp::getNumDataOperands() {
2383 return getReductionOperands().size() + getPrivateOperands().size() +
2384 getFirstprivateOperands().size() + getDataClauseOperands().size();
2387Value SerialOp::getDataOperand(
unsigned i) {
2389 numOptional += getIfCond() ? 1 : 0;
2390 numOptional += getSelfCond() ? 1 : 0;
2391 return getOperand(getWaitOperands().size() + numOptional + i);
2394bool acc::SerialOp::hasAsyncOnly() {
2395 return hasAsyncOnly(mlir::acc::DeviceType::None);
2398bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2403 return getAsyncValue(mlir::acc::DeviceType::None);
2406mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2411bool acc::SerialOp::hasWaitOnly() {
2412 return hasWaitOnly(mlir::acc::DeviceType::None);
2415bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2420 return getWaitValues(mlir::acc::DeviceType::None);
2424SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2426 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2427 getHasWaitDevnum(), deviceType);
2431 return getWaitDevnum(mlir::acc::DeviceType::None);
2434mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2436 getWaitOperandsSegments(), getHasWaitDevnum(),
2440LogicalResult acc::SerialOp::verify() {
2442 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
2443 "privatizations",
false)))
2446 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
2447 "firstprivate",
"firstprivatizations",
false)))
2450 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
2451 "reductions",
false)))
2455 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2456 getWaitOperandsDeviceTypeAttr(),
"wait")))
2460 getAsyncOperandsDeviceTypeAttr(),
2470void acc::SerialOp::addAsyncOnly(
2472 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2473 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2476void acc::SerialOp::addAsyncOperand(
2479 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2480 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2481 getAsyncOperandsMutable()));
2484void acc::SerialOp::addWaitOnly(
2486 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2487 effectiveDeviceTypes));
2489void acc::SerialOp::addWaitOperands(
2494 if (getWaitOperandsSegments())
2495 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2497 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2498 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2499 getWaitOperandsMutable(), segments));
2500 setWaitOperandsSegments(segments);
2503 if (getHasWaitDevnumAttr())
2504 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2507 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2509 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2512void acc::SerialOp::addPrivatization(
MLIRContext *context,
2513 mlir::acc::PrivateOp op,
2514 mlir::acc::PrivateRecipeOp recipe) {
2515 getPrivateOperandsMutable().append(op.getResult());
2519 if (getPrivatizationRecipesAttr())
2520 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
2523 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2524 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2527void acc::SerialOp::addFirstPrivatization(
2528 MLIRContext *context, mlir::acc::FirstprivateOp op,
2529 mlir::acc::FirstprivateRecipeOp recipe) {
2530 getFirstprivateOperandsMutable().append(op.getResult());
2534 if (getFirstprivatizationRecipesAttr())
2535 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
2538 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2539 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2542void acc::SerialOp::addReduction(
MLIRContext *context,
2543 mlir::acc::ReductionOp op,
2544 mlir::acc::ReductionRecipeOp recipe) {
2545 getReductionOperandsMutable().append(op.getResult());
2549 if (getReductionRecipesAttr())
2550 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
2553 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
2554 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
2561unsigned KernelsOp::getNumDataOperands() {
2562 return getDataClauseOperands().size();
2565Value KernelsOp::getDataOperand(
unsigned i) {
2567 numOptional += getWaitOperands().size();
2568 numOptional += getNumGangs().size();
2569 numOptional += getNumWorkers().size();
2570 numOptional += getVectorLength().size();
2571 numOptional += getIfCond() ? 1 : 0;
2572 numOptional += getSelfCond() ? 1 : 0;
2573 return getOperand(numOptional + i);
2576bool acc::KernelsOp::hasAsyncOnly() {
2577 return hasAsyncOnly(mlir::acc::DeviceType::None);
2580bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2585 return getAsyncValue(mlir::acc::DeviceType::None);
2588mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2594 return getNumWorkersValue(mlir::acc::DeviceType::None);
2598acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2603mlir::Value acc::KernelsOp::getVectorLengthValue() {
2604 return getVectorLengthValue(mlir::acc::DeviceType::None);
2608acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2610 getVectorLength(), deviceType);
2614 return getNumGangsValues(mlir::acc::DeviceType::None);
2618KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2620 getNumGangsSegments(), deviceType);
2623bool acc::KernelsOp::hasWaitOnly() {
2624 return hasWaitOnly(mlir::acc::DeviceType::None);
2627bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2632 return getWaitValues(mlir::acc::DeviceType::None);
2636KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2638 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2639 getHasWaitDevnum(), deviceType);
2643 return getWaitDevnum(mlir::acc::DeviceType::None);
2646mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2648 getWaitOperandsSegments(), getHasWaitDevnum(),
2652LogicalResult acc::KernelsOp::verify() {
2654 *
this, getNumGangs(), getNumGangsSegmentsAttr(),
2655 getNumGangsDeviceTypeAttr(),
"num_gangs", 3)))
2659 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2660 getWaitOperandsDeviceTypeAttr(),
"wait")))
2664 getNumWorkersDeviceTypeAttr(),
2669 getVectorLengthDeviceTypeAttr(),
2674 getAsyncOperandsDeviceTypeAttr(),
2684void acc::KernelsOp::addNumWorkersOperand(
2687 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2688 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2689 getNumWorkersMutable()));
2692void acc::KernelsOp::addVectorLengthOperand(
2695 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2696 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2697 getVectorLengthMutable()));
2699void acc::KernelsOp::addAsyncOnly(
2701 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2702 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2705void acc::KernelsOp::addAsyncOperand(
2708 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2709 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2710 getAsyncOperandsMutable()));
2713void acc::KernelsOp::addNumGangsOperands(
2717 if (getNumGangsSegmentsAttr())
2718 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2720 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2721 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2722 getNumGangsMutable(), segments));
2724 setNumGangsSegments(segments);
2727void acc::KernelsOp::addWaitOnly(
2729 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2730 effectiveDeviceTypes));
2732void acc::KernelsOp::addWaitOperands(
2737 if (getWaitOperandsSegments())
2738 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2740 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2741 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2742 getWaitOperandsMutable(), segments));
2743 setWaitOperandsSegments(segments);
2746 if (getHasWaitDevnumAttr())
2747 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2750 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
2752 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2759LogicalResult acc::HostDataOp::verify() {
2760 if (getDataClauseOperands().empty())
2761 return emitError(
"at least one operand must appear on the host_data "
2764 for (
mlir::Value operand : getDataClauseOperands())
2765 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2766 return emitError(
"expect data entry operation as defining op");
2772 results.
add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2779void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2781 results.
add<RemoveEmptyKernelEnvironment>(context);
2793 bool &needCommaBetweenValues,
bool &newValue) {
2800 attributes.push_back(gangArgType);
2801 needCommaBetweenValues =
true;
2812 mlir::ArrayAttr &gangOnlyDeviceType) {
2817 bool needCommaBetweenValues =
false;
2818 bool needCommaBeforeOperands =
false;
2822 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2823 parser.
getContext(), mlir::acc::DeviceType::None));
2824 gangOnlyDeviceType =
2825 ArrayAttr::get(parser.
getContext(), gangOnlyDeviceTypeAttributes);
2833 if (parser.parseAttribute(
2834 gangOnlyDeviceTypeAttributes.emplace_back()))
2841 needCommaBeforeOperands =
true;
2844 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2845 mlir::acc::GangArgType::Num);
2846 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.
getContext(),
2847 mlir::acc::GangArgType::Dim);
2848 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2849 parser.
getContext(), mlir::acc::GangArgType::Static);
2852 if (needCommaBeforeOperands) {
2853 needCommaBeforeOperands =
false;
2860 int32_t crtOperandsSize = gangOperands.size();
2862 bool newValue =
false;
2863 bool needValue =
false;
2864 if (needCommaBetweenValues) {
2872 gangOperands, gangOperandsType,
2873 gangArgTypeAttributes, argNum,
2874 needCommaBetweenValues, newValue)))
2877 gangOperands, gangOperandsType,
2878 gangArgTypeAttributes, argDim,
2879 needCommaBetweenValues, newValue)))
2881 if (failed(
parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2882 gangOperands, gangOperandsType,
2883 gangArgTypeAttributes, argStatic,
2884 needCommaBetweenValues, newValue)))
2887 if (!newValue && needValue) {
2889 "new value expected after comma");
2897 if (gangOperands.empty())
2900 "expect at least one of num, dim or static values");
2906 if (parser.
parseAttribute(deviceTypeAttributes.emplace_back()) ||
2910 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2911 parser.
getContext(), mlir::acc::DeviceType::None));
2914 seg.push_back(gangOperands.size() - crtOperandsSize);
2922 gangArgTypeAttributes.end());
2923 gangArgType = ArrayAttr::get(parser.
getContext(), arrayAttr);
2924 deviceType = ArrayAttr::get(parser.
getContext(), deviceTypeAttributes);
2927 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
2928 gangOnlyDeviceType = ArrayAttr::get(parser.
getContext(), gangOnlyAttr);
2936 std::optional<mlir::ArrayAttr> gangArgTypes,
2937 std::optional<mlir::ArrayAttr> deviceTypes,
2938 std::optional<mlir::DenseI32ArrayAttr> segments,
2939 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
2941 if (operands.begin() == operands.end() &&
2956 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](
auto it) {
2958 llvm::interleaveComma(
2959 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](
auto it) {
2960 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
2961 (*gangArgTypes)[opIdx]);
2962 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
2963 p << LoopOp::getGangNumKeyword();
2964 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
2965 p << LoopOp::getGangDimKeyword();
2966 else if (gangArgTypeAttr.getValue() ==
2967 mlir::acc::GangArgType::Static)
2968 p << LoopOp::getGangStaticKeyword();
2969 p <<
"=" << operands[opIdx] <<
" : " << operands[opIdx].getType();
2980 std::optional<mlir::ArrayAttr> segments,
2981 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
2984 for (
auto attr : *segments) {
2985 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2986 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
2994 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
2997 for (
auto attr : deviceTypes) {
2998 auto deviceTypeAttr =
2999 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3000 if (!deviceTypeAttr)
3002 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3008LogicalResult acc::LoopOp::verify() {
3009 if (getUpperbound().size() != getStep().size())
3010 return emitError() <<
"number of upperbounds expected to be the same as "
3013 if (getUpperbound().size() != getLowerbound().size())
3014 return emitError() <<
"number of upperbounds expected to be the same as "
3015 "number of lowerbounds";
3017 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3018 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3019 return emitError() <<
"inclusiveUpperbound size is expected to be the same"
3020 <<
" as upperbound size";
3023 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3024 return emitOpError() <<
"collapse device_type attr must be define when"
3025 <<
" collapse attr is present";
3027 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3028 getCollapseAttr().getValue().size() !=
3029 getCollapseDeviceTypeAttr().getValue().size())
3030 return emitOpError() <<
"collapse attribute count must match collapse"
3031 <<
" device_type count";
3034 <<
"duplicate device_type found in collapseDeviceType attribute";
3037 if (!getGangOperands().empty()) {
3038 if (!getGangOperandsArgType())
3039 return emitOpError() <<
"gangOperandsArgType attribute must be defined"
3040 <<
" when gang operands are present";
3042 if (getGangOperands().size() !=
3043 getGangOperandsArgTypeAttr().getValue().size())
3044 return emitOpError() <<
"gangOperandsArgType attribute count must match"
3045 <<
" gangOperands count";
3048 return emitOpError() <<
"duplicate device_type found in gang attribute";
3051 *
this, getGangOperands(), getGangOperandsSegmentsAttr(),
3052 getGangOperandsDeviceTypeAttr(),
"gang")))
3057 return emitOpError() <<
"duplicate device_type found in worker attribute";
3059 return emitOpError() <<
"duplicate device_type found in "
3060 "workerNumOperandsDeviceType attribute";
3062 getWorkerNumOperandsDeviceTypeAttr(),
3068 return emitOpError() <<
"duplicate device_type found in vector attribute";
3070 return emitOpError() <<
"duplicate device_type found in "
3071 "vectorOperandsDeviceType attribute";
3073 getVectorOperandsDeviceTypeAttr(),
3078 *
this, getTileOperands(), getTileOperandsSegmentsAttr(),
3079 getTileOperandsDeviceTypeAttr(),
"tile")))
3083 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3087 return emitError() <<
"only one of auto, independent, seq can be present "
3093 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) ->
bool {
3094 return attr.getValue() == mlir::acc::DeviceType::None;
3096 bool hasDefaultSeq =
3098 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3101 bool hasDefaultIndependent =
3102 getIndependentAttr()
3104 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3107 bool hasDefaultAuto =
3109 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3112 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3114 <<
"at least one of auto, independent, seq must be present";
3119 for (
auto attr : getSeqAttr()) {
3120 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3121 if (hasVector(deviceTypeAttr.getValue()) ||
3122 getVectorValue(deviceTypeAttr.getValue()) ||
3123 hasWorker(deviceTypeAttr.getValue()) ||
3124 getWorkerValue(deviceTypeAttr.getValue()) ||
3125 hasGang(deviceTypeAttr.getValue()) ||
3126 getGangValue(mlir::acc::GangArgType::Num,
3127 deviceTypeAttr.getValue()) ||
3128 getGangValue(mlir::acc::GangArgType::Dim,
3129 deviceTypeAttr.getValue()) ||
3130 getGangValue(mlir::acc::GangArgType::Static,
3131 deviceTypeAttr.getValue()))
3132 return emitError() <<
"gang, worker or vector cannot appear with seq";
3137 *
this, getPrivatizationRecipes(), getPrivateOperands(),
"private",
3138 "privatizations",
false)))
3142 *
this, getFirstprivatizationRecipes(), getFirstprivateOperands(),
3143 "firstprivate",
"firstprivatizations",
false)))
3147 *
this, getReductionRecipes(), getReductionOperands(),
"reduction",
3148 "reductions",
false)))
3151 if (getCombined().has_value() &&
3152 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3153 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3154 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3155 return emitError(
"unexpected combined constructs attribute");
3159 if (getRegion().empty())
3160 return emitError(
"expected non-empty body.");
3162 if (getUnstructured()) {
3163 if (!isContainerLike())
3165 "unstructured acc.loop must not have induction variables");
3166 }
else if (isContainerLike()) {
3170 uint64_t collapseCount = getCollapseValue().value_or(1);
3171 if (getCollapseAttr()) {
3172 for (
auto collapseEntry : getCollapseAttr()) {
3173 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3174 if (intAttr.getValue().getZExtValue() > collapseCount)
3175 collapseCount = intAttr.getValue().getZExtValue();
3183 bool foundSibling =
false;
3185 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3187 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3189 foundSibling =
true;
3194 expectedParent = op;
3197 if (collapseCount == 0)
3203 return emitError(
"found sibling loops inside container-like acc.loop");
3204 if (collapseCount != 0)
3205 return emitError(
"failed to find enough loop-like operations inside "
3206 "container-like acc.loop");
3212unsigned LoopOp::getNumDataOperands() {
3213 return getReductionOperands().size() + getPrivateOperands().size() +
3214 getFirstprivateOperands().size();
3217Value LoopOp::getDataOperand(
unsigned i) {
3218 unsigned numOptional =
3219 getLowerbound().size() + getUpperbound().size() + getStep().size();
3220 numOptional += getGangOperands().size();
3221 numOptional += getVectorOperands().size();
3222 numOptional += getWorkerNumOperands().size();
3223 numOptional += getTileOperands().size();
3224 numOptional += getCacheOperands().size();
3225 return getOperand(numOptional + i);
3228bool LoopOp::hasAuto() {
return hasAuto(mlir::acc::DeviceType::None); }
3230bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3234bool LoopOp::hasIndependent() {
3235 return hasIndependent(mlir::acc::DeviceType::None);
3238bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3242bool LoopOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
3244bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3249 return getVectorValue(mlir::acc::DeviceType::None);
3252mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3254 getVectorOperands(), deviceType);
3257bool LoopOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
3259bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3264 return getWorkerValue(mlir::acc::DeviceType::None);
3267mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3269 getWorkerNumOperands(), deviceType);
3272bool LoopOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
3274bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3279 return getTileValues(mlir::acc::DeviceType::None);
3283LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3285 getTileOperandsSegments(), deviceType);
3288std::optional<int64_t> LoopOp::getCollapseValue() {
3289 return getCollapseValue(mlir::acc::DeviceType::None);
3292std::optional<int64_t>
3293LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3294 if (!getCollapseAttr())
3295 return std::nullopt;
3296 if (
auto pos =
findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3298 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3299 return intAttr.getValue().getZExtValue();
3301 return std::nullopt;
3304mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3305 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3308mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3309 mlir::acc::DeviceType deviceType) {
3310 if (getGangOperands().empty())
3312 if (
auto pos =
findSegment(*getGangOperandsDeviceType(), deviceType)) {
3313 int32_t nbOperandsBefore = 0;
3314 for (
unsigned i = 0; i < *pos; ++i)
3315 nbOperandsBefore += (*getGangOperandsSegments())[i];
3318 .drop_front(nbOperandsBefore)
3319 .take_front((*getGangOperandsSegments())[*pos]);
3321 int32_t argTypeIdx = nbOperandsBefore;
3322 for (
auto value : values) {
3323 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3324 (*getGangOperandsArgType())[argTypeIdx]);
3325 if (gangArgTypeAttr.getValue() == gangArgType)
3333bool LoopOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
3335bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3340 return {&getRegion()};
3384 if (!regionArgs.empty()) {
3385 p << acc::LoopOp::getControlKeyword() <<
"(";
3386 llvm::interleaveComma(regionArgs, p,
3388 p <<
") = (" << lowerbound <<
" : " << lowerboundType <<
") to ("
3389 << upperbound <<
" : " << upperboundType <<
") " <<
" step (" << steps
3390 <<
" : " << stepType <<
") ";
3397 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3398 effectiveDeviceTypes));
3401void acc::LoopOp::addIndependent(
3403 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3404 context, getIndependentAttr(), effectiveDeviceTypes));
3409 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3410 effectiveDeviceTypes));
3413void acc::LoopOp::setCollapseForDeviceTypes(
3415 llvm::APInt value) {
3419 assert((getCollapseAttr() ==
nullptr) ==
3420 (getCollapseDeviceTypeAttr() ==
nullptr));
3421 assert(value.getBitWidth() == 64);
3423 if (getCollapseAttr()) {
3424 for (
const auto &existing :
3425 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3426 newValues.push_back(std::get<0>(existing));
3427 newDeviceTypes.push_back(std::get<1>(existing));
3431 if (effectiveDeviceTypes.empty()) {
3434 newValues.push_back(
3435 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3436 newDeviceTypes.push_back(
3437 acc::DeviceTypeAttr::get(context, DeviceType::None));
3439 for (DeviceType dt : effectiveDeviceTypes) {
3440 newValues.push_back(
3441 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3442 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3446 setCollapseAttr(ArrayAttr::get(context, newValues));
3447 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3450void acc::LoopOp::setTileForDeviceTypes(
3454 if (getTileOperandsSegments())
3455 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3457 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3458 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3459 getTileOperandsMutable(), segments));
3461 setTileOperandsSegments(segments);
3464void acc::LoopOp::addVectorOperand(
3467 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3468 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3469 newValue, getVectorOperandsMutable()));
3472void acc::LoopOp::addEmptyVector(
3474 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3475 effectiveDeviceTypes));
3478void acc::LoopOp::addWorkerNumOperand(
3481 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3482 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3483 newValue, getWorkerNumOperandsMutable()));
3486void acc::LoopOp::addEmptyWorker(
3488 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3489 effectiveDeviceTypes));
3492void acc::LoopOp::addEmptyGang(
3494 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3495 effectiveDeviceTypes));
3498bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3499 auto hasDevice = [=](DeviceTypeAttr attr) ->
bool {
3500 return attr.getValue() == dt;
3502 auto testFromArr = [=](
ArrayAttr arr) ->
bool {
3503 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3506 if (
ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3508 if (
ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3510 if (
ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3516bool acc::LoopOp::hasDefaultGangWorkerVector() {
3517 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3518 hasGang() || getGangValue(GangArgType::Num) ||
3519 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3523acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3524 if (hasSeq(deviceType))
3525 return LoopParMode::loop_seq;
3526 if (hasAuto(deviceType))
3527 return LoopParMode::loop_auto;
3528 if (hasIndependent(deviceType))
3529 return LoopParMode::loop_independent;
3531 return LoopParMode::loop_seq;
3533 return LoopParMode::loop_auto;
3534 assert(hasIndependent() &&
3535 "loop must have default auto, seq, or independent");
3536 return LoopParMode::loop_independent;
3539void acc::LoopOp::addGangOperands(
3544 getGangOperandsSegments())
3545 llvm::copy(*existingSegments, std::back_inserter(segments));
3547 unsigned beforeCount = segments.size();
3549 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3550 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3551 getGangOperandsMutable(), segments));
3553 setGangOperandsSegments(segments);
3560 unsigned numAdded = segments.size() - beforeCount;
3564 if (getGangOperandsArgTypeAttr())
3565 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3567 for (
auto i : llvm::index_range(0u, numAdded)) {
3568 llvm::transform(argTypes, std::back_inserter(gangTypes),
3569 [=](mlir::acc::GangArgType gangTy) {
3570 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3575 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3579void acc::LoopOp::addPrivatization(
MLIRContext *context,
3580 mlir::acc::PrivateOp op,
3581 mlir::acc::PrivateRecipeOp recipe) {
3582 getPrivateOperandsMutable().append(op.getResult());
3586 if (getPrivatizationRecipesAttr())
3587 llvm::copy(getPrivatizationRecipesAttr(), std::back_inserter(recipes));
3590 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3591 setPrivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3594void acc::LoopOp::addFirstPrivatization(
3595 MLIRContext *context, mlir::acc::FirstprivateOp op,
3596 mlir::acc::FirstprivateRecipeOp recipe) {
3597 getFirstprivateOperandsMutable().append(op.getResult());
3601 if (getFirstprivatizationRecipesAttr())
3602 llvm::copy(getFirstprivatizationRecipesAttr(), std::back_inserter(recipes));
3605 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3606 setFirstprivatizationRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3609void acc::LoopOp::addReduction(
MLIRContext *context, mlir::acc::ReductionOp op,
3610 mlir::acc::ReductionRecipeOp recipe) {
3611 getReductionOperandsMutable().append(op.getResult());
3615 if (getReductionRecipesAttr())
3616 llvm::copy(getReductionRecipesAttr(), std::back_inserter(recipes));
3619 mlir::SymbolRefAttr::get(context, recipe.getSymName().str()));
3620 setReductionRecipesAttr(mlir::ArrayAttr::get(context, recipes));
3627LogicalResult acc::DataOp::verify() {
3632 return emitError(
"at least one operand or the default attribute "
3633 "must appear on the data operation");
3635 for (
mlir::Value operand : getDataClauseOperands())
3636 if (isa<BlockArgument>(operand) ||
3637 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3638 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3639 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3640 operand.getDefiningOp()))
3641 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
3650unsigned DataOp::getNumDataOperands() {
return getDataClauseOperands().size(); }
3652Value DataOp::getDataOperand(
unsigned i) {
3653 unsigned numOptional = getIfCond() ? 1 : 0;
3655 numOptional += getWaitOperands().size();
3656 return getOperand(numOptional + i);
3659bool acc::DataOp::hasAsyncOnly() {
3660 return hasAsyncOnly(mlir::acc::DeviceType::None);
3663bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3668 return getAsyncValue(mlir::acc::DeviceType::None);
3671mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3676bool DataOp::hasWaitOnly() {
return hasWaitOnly(mlir::acc::DeviceType::None); }
3678bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3683 return getWaitValues(mlir::acc::DeviceType::None);
3687DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3689 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3690 getHasWaitDevnum(), deviceType);
3694 return getWaitDevnum(mlir::acc::DeviceType::None);
3697mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3699 getWaitOperandsSegments(), getHasWaitDevnum(),
3703void acc::DataOp::addAsyncOnly(
3705 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3706 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3709void acc::DataOp::addAsyncOperand(
3712 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3713 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3714 getAsyncOperandsMutable()));
3717void acc::DataOp::addWaitOnly(
MLIRContext *context,
3719 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3720 effectiveDeviceTypes));
3723void acc::DataOp::addWaitOperands(
3728 if (getWaitOperandsSegments())
3729 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3731 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3732 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3733 getWaitOperandsMutable(), segments));
3734 setWaitOperandsSegments(segments);
3737 if (getHasWaitDevnumAttr())
3738 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3741 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
3743 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3750LogicalResult acc::ExitDataOp::verify() {
3754 if (getDataClauseOperands().empty())
3755 return emitError(
"at least one operand must be present in dataOperands on "
3756 "the exit data operation");
3760 if (getAsyncOperand() && getAsync())
3761 return emitError(
"async attribute cannot appear with asyncOperand");
3765 if (!getWaitOperands().empty() && getWait())
3766 return emitError(
"wait attribute cannot appear with waitOperands");
3768 if (getWaitDevnum() && getWaitOperands().empty())
3769 return emitError(
"wait_devnum cannot appear without waitOperands");
3774unsigned ExitDataOp::getNumDataOperands() {
3775 return getDataClauseOperands().size();
3778Value ExitDataOp::getDataOperand(
unsigned i) {
3779 unsigned numOptional = getIfCond() ? 1 : 0;
3780 numOptional += getAsyncOperand() ? 1 : 0;
3781 numOptional += getWaitDevnum() ? 1 : 0;
3782 return getOperand(getWaitOperands().size() + numOptional + i);
3787 results.
add<RemoveConstantIfCondition<ExitDataOp>>(context);
3790void ExitDataOp::addAsyncOnly(
MLIRContext *context,
3792 assert(effectiveDeviceTypes.empty());
3793 assert(!getAsyncAttr());
3794 assert(!getAsyncOperand());
3796 setAsyncAttr(mlir::UnitAttr::get(context));
3799void ExitDataOp::addAsyncOperand(
3802 assert(effectiveDeviceTypes.empty());
3803 assert(!getAsyncAttr());
3804 assert(!getAsyncOperand());
3806 getAsyncOperandMutable().append(newValue);
3811 assert(effectiveDeviceTypes.empty());
3812 assert(!getWaitAttr());
3813 assert(getWaitOperands().empty());
3814 assert(!getWaitDevnum());
3816 setWaitAttr(mlir::UnitAttr::get(context));
3819void ExitDataOp::addWaitOperands(
3822 assert(effectiveDeviceTypes.empty());
3823 assert(!getWaitAttr());
3824 assert(getWaitOperands().empty());
3825 assert(!getWaitDevnum());
3830 getWaitDevnumMutable().append(newValues.front());
3831 newValues = newValues.drop_front();
3834 getWaitOperandsMutable().append(newValues);
3841LogicalResult acc::EnterDataOp::verify() {
3845 if (getDataClauseOperands().empty())
3846 return emitError(
"at least one operand must be present in dataOperands on "
3847 "the enter data operation");
3851 if (getAsyncOperand() && getAsync())
3852 return emitError(
"async attribute cannot appear with asyncOperand");
3856 if (!getWaitOperands().empty() && getWait())
3857 return emitError(
"wait attribute cannot appear with waitOperands");
3859 if (getWaitDevnum() && getWaitOperands().empty())
3860 return emitError(
"wait_devnum cannot appear without waitOperands");
3862 for (
mlir::Value operand : getDataClauseOperands())
3863 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3864 operand.getDefiningOp()))
3865 return emitError(
"expect data entry operation as defining op");
3870unsigned EnterDataOp::getNumDataOperands() {
3871 return getDataClauseOperands().size();
3874Value EnterDataOp::getDataOperand(
unsigned i) {
3875 unsigned numOptional = getIfCond() ? 1 : 0;
3876 numOptional += getAsyncOperand() ? 1 : 0;
3877 numOptional += getWaitDevnum() ? 1 : 0;
3878 return getOperand(getWaitOperands().size() + numOptional + i);
3883 results.
add<RemoveConstantIfCondition<EnterDataOp>>(context);
3886void EnterDataOp::addAsyncOnly(
3888 assert(effectiveDeviceTypes.empty());
3889 assert(!getAsyncAttr());
3890 assert(!getAsyncOperand());
3892 setAsyncAttr(mlir::UnitAttr::get(context));
3895void EnterDataOp::addAsyncOperand(
3898 assert(effectiveDeviceTypes.empty());
3899 assert(!getAsyncAttr());
3900 assert(!getAsyncOperand());
3902 getAsyncOperandMutable().append(newValue);
3905void EnterDataOp::addWaitOnly(
MLIRContext *context,
3907 assert(effectiveDeviceTypes.empty());
3908 assert(!getWaitAttr());
3909 assert(getWaitOperands().empty());
3910 assert(!getWaitDevnum());
3912 setWaitAttr(mlir::UnitAttr::get(context));
3915void EnterDataOp::addWaitOperands(
3918 assert(effectiveDeviceTypes.empty());
3919 assert(!getWaitAttr());
3920 assert(getWaitOperands().empty());
3921 assert(!getWaitDevnum());
3926 getWaitDevnumMutable().append(newValues.front());
3927 newValues = newValues.drop_front();
3930 getWaitOperandsMutable().append(newValues);
3937LogicalResult AtomicReadOp::verify() {
return verifyCommon(); }
3943LogicalResult AtomicWriteOp::verify() {
return verifyCommon(); }
3949LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3956 if (
Value writeVal = op.getWriteOpVal()) {
3965LogicalResult AtomicUpdateOp::verify() {
return verifyCommon(); }
3967LogicalResult AtomicUpdateOp::verifyRegions() {
return verifyRegionsCommon(); }
3973AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3974 if (
auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3976 return dyn_cast<AtomicReadOp>(getSecondOp());
3979AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3980 if (
auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3982 return dyn_cast<AtomicWriteOp>(getSecondOp());
3985AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3986 if (
auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3988 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3991LogicalResult AtomicCaptureOp::verifyRegions() {
return verifyRegionsCommon(); }
3997template <
typename Op>
4000 bool requireAtLeastOneOperand =
true) {
4001 if (operands.empty() && requireAtLeastOneOperand)
4004 "at least one operand must appear on the declare operation");
4007 if (isa<BlockArgument>(operand) ||
4008 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4009 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4010 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4011 operand.getDefiningOp()))
4013 "expect valid declare data entry operation or acc.getdeviceptr "
4017 assert(var &&
"declare operands can only be data entry operations which "
4020 std::optional<mlir::acc::DataClause> dataClauseOptional{
4022 assert(dataClauseOptional.has_value() &&
4023 "declare operands can only be data entry operations which must have "
4025 (
void)dataClauseOptional;
4031LogicalResult acc::DeclareEnterOp::verify() {
4039LogicalResult acc::DeclareExitOp::verify() {
4050LogicalResult acc::DeclareOp::verify() {
4059 acc::DeviceType dtype) {
4060 unsigned parallelism = 0;
4061 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4062 parallelism += op.hasWorker(dtype) ? 1 : 0;
4063 parallelism += op.hasVector(dtype) ? 1 : 0;
4064 parallelism += op.hasSeq(dtype) ? 1 : 0;
4068LogicalResult acc::RoutineOp::verify() {
4069 unsigned baseParallelism =
4072 if (baseParallelism > 1)
4073 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4074 "be present at the same time";
4076 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4078 auto dtype =
static_cast<acc::DeviceType
>(dtypeInt);
4079 if (dtype == acc::DeviceType::None)
4083 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4084 return emitError() <<
"only one of `gang`, `worker`, `vector`, `seq` can "
4085 "be present at the same time";
4092 mlir::ArrayAttr &bindIdName,
4093 mlir::ArrayAttr &bindStrName,
4094 mlir::ArrayAttr &deviceIdTypes,
4095 mlir::ArrayAttr &deviceStrTypes) {
4102 mlir::Attribute newAttr;
4103 bool isSymbolRefAttr;
4104 auto parseResult = parser.parseAttribute(newAttr);
4105 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4106 bindIdNameAttrs.push_back(symbolRefAttr);
4107 isSymbolRefAttr = true;
4108 }
else if (
auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4109 bindStrNameAttrs.push_back(stringAttr);
4110 isSymbolRefAttr =
false;
4115 if (isSymbolRefAttr) {
4116 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4117 parser.getContext(), mlir::acc::DeviceType::None));
4119 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4120 parser.getContext(), mlir::acc::DeviceType::None));
4123 if (isSymbolRefAttr) {
4124 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4125 parser.parseRSquare())
4128 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4129 parser.parseRSquare())
4137 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4138 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4139 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4140 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4146 std::optional<mlir::ArrayAttr> bindIdName,
4147 std::optional<mlir::ArrayAttr> bindStrName,
4148 std::optional<mlir::ArrayAttr> deviceIdTypes,
4149 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4156 allBindNames.append(bindIdName->begin(), bindIdName->end());
4157 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4162 allBindNames.append(bindStrName->begin(), bindStrName->end());
4163 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4167 if (!allBindNames.empty())
4168 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4169 [&](
const auto &pair) {
4170 p << std::get<0>(pair);
4176 mlir::ArrayAttr &gang,
4177 mlir::ArrayAttr &gangDim,
4178 mlir::ArrayAttr &gangDimDeviceTypes) {
4181 gangDimDeviceTypeAttrs;
4182 bool needCommaBeforeOperands =
false;
4186 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4187 parser.
getContext(), mlir::acc::DeviceType::None));
4188 gang = ArrayAttr::get(parser.
getContext(), gangAttrs);
4195 if (parser.parseAttribute(gangAttrs.emplace_back()))
4202 needCommaBeforeOperands =
true;
4205 if (needCommaBeforeOperands && failed(parser.
parseComma()))
4209 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4210 parser.parseColon() ||
4211 parser.parseAttribute(gangDimAttrs.emplace_back()))
4213 if (succeeded(parser.parseOptionalLSquare())) {
4214 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4215 parser.parseRSquare())
4218 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4219 parser.getContext(), mlir::acc::DeviceType::None));
4225 if (
failed(parser.parseRParen()))
4228 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4229 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4230 gangDimDeviceTypes =
4231 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4237 std::optional<mlir::ArrayAttr> gang,
4238 std::optional<mlir::ArrayAttr> gangDim,
4239 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4242 gang->size() == 1) {
4243 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4244 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4256 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4257 [&](
const auto &pair) {
4258 p << acc::RoutineOp::getGangDimKeyword() <<
": ";
4259 p << std::get<0>(pair);
4267 mlir::ArrayAttr &deviceTypes) {
4271 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4272 parser.
getContext(), mlir::acc::DeviceType::None));
4273 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4280 if (parser.parseAttribute(attributes.emplace_back()))
4288 deviceTypes = ArrayAttr::get(parser.
getContext(), attributes);
4294 std::optional<mlir::ArrayAttr> deviceTypes) {
4297 auto deviceTypeAttr =
4298 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4299 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4308 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4314bool RoutineOp::hasWorker() {
return hasWorker(mlir::acc::DeviceType::None); }
4316bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4320bool RoutineOp::hasVector() {
return hasVector(mlir::acc::DeviceType::None); }
4322bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4326bool RoutineOp::hasSeq() {
return hasSeq(mlir::acc::DeviceType::None); }
4328bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4332std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4333RoutineOp::getBindNameValue() {
4334 return getBindNameValue(mlir::acc::DeviceType::None);
4337std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4338RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4341 return std::nullopt;
4344 if (
auto pos =
findSegment(*getBindIdNameDeviceType(), deviceType)) {
4345 auto attr = (*getBindIdName())[*pos];
4346 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4347 assert(symbolRefAttr &&
"expected SymbolRef");
4348 return symbolRefAttr;
4351 if (
auto pos =
findSegment(*getBindStrNameDeviceType(), deviceType)) {
4352 auto attr = (*getBindStrName())[*pos];
4353 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4354 assert(stringAttr &&
"expected String");
4358 return std::nullopt;
4361bool RoutineOp::hasGang() {
return hasGang(mlir::acc::DeviceType::None); }
4363bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4367std::optional<int64_t> RoutineOp::getGangDimValue() {
4368 return getGangDimValue(mlir::acc::DeviceType::None);
4371std::optional<int64_t>
4372RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4374 return std::nullopt;
4375 if (
auto pos =
findSegment(*getGangDimDeviceType(), deviceType)) {
4376 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4377 return intAttr.getInt();
4379 return std::nullopt;
4386LogicalResult acc::InitOp::verify() {
4390 return emitOpError(
"cannot be nested in a compute operation");
4394void acc::InitOp::addDeviceType(
MLIRContext *context,
4395 mlir::acc::DeviceType deviceType) {
4397 if (getDeviceTypesAttr())
4398 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4400 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4401 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4408LogicalResult acc::ShutdownOp::verify() {
4412 return emitOpError(
"cannot be nested in a compute operation");
4416void acc::ShutdownOp::addDeviceType(
MLIRContext *context,
4417 mlir::acc::DeviceType deviceType) {
4419 if (getDeviceTypesAttr())
4420 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4422 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4423 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4430LogicalResult acc::SetOp::verify() {
4434 return emitOpError(
"cannot be nested in a compute operation");
4435 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4436 return emitOpError(
"at least one default_async, device_num, or device_type "
4437 "operand must appear");
4445LogicalResult acc::UpdateOp::verify() {
4447 if (getDataClauseOperands().empty())
4448 return emitError(
"at least one value must be present in dataOperands");
4451 getAsyncOperandsDeviceTypeAttr(),
4456 *
this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4457 getWaitOperandsDeviceTypeAttr(),
"wait")))
4463 for (
mlir::Value operand : getDataClauseOperands())
4464 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4465 operand.getDefiningOp()))
4466 return emitError(
"expect data entry/exit operation or acc.getdeviceptr "
4472unsigned UpdateOp::getNumDataOperands() {
4473 return getDataClauseOperands().size();
4476Value UpdateOp::getDataOperand(
unsigned i) {
4478 numOptional += getIfCond() ? 1 : 0;
4479 return getOperand(getWaitOperands().size() + numOptional + i);
4484 results.
add<RemoveConstantIfCondition<UpdateOp>>(context);
4487bool UpdateOp::hasAsyncOnly() {
4488 return hasAsyncOnly(mlir::acc::DeviceType::None);
4491bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4496 return getAsyncValue(mlir::acc::DeviceType::None);
4499mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4509bool UpdateOp::hasWaitOnly() {
4510 return hasWaitOnly(mlir::acc::DeviceType::None);
4513bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4518 return getWaitValues(mlir::acc::DeviceType::None);
4522UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4524 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4525 getHasWaitDevnum(), deviceType);
4529 return getWaitDevnum(mlir::acc::DeviceType::None);
4532mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4534 getWaitOperandsSegments(), getHasWaitDevnum(),
4540 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4541 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4544void UpdateOp::addAsyncOperand(
4547 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4548 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4549 getAsyncOperandsMutable()));
4554 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4555 effectiveDeviceTypes));
4558void UpdateOp::addWaitOperands(
4563 if (getWaitOperandsSegments())
4564 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4566 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4567 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4568 getWaitOperandsMutable(), segments));
4569 setWaitOperandsSegments(segments);
4572 if (getHasWaitDevnumAttr())
4573 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4576 std::max(effectiveDeviceTypes.size(),
static_cast<size_t>(1)),
4578 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4585LogicalResult acc::WaitOp::verify() {
4588 if (getAsyncOperand() && getAsync())
4589 return emitError(
"async attribute cannot appear with asyncOperand");
4591 if (getWaitDevnum() && getWaitOperands().empty())
4592 return emitError(
"wait_devnum cannot appear without waitOperands");
4597#define GET_OP_CLASSES
4598#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4600#define GET_ATTRDEF_CLASSES
4601#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4603#define GET_TYPEDEF_CLASSES
4604#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4615 .Case<ACC_DATA_ENTRY_OPS>(
4616 [&](
auto entry) {
return entry.getVarPtr(); })
4617 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4618 [&](
auto exit) {
return exit.getVarPtr(); })
4636 [&](
auto entry) {
return entry.getVarType(); })
4637 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4638 [&](
auto exit) {
return exit.getVarType(); })
4648 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4649 [&](
auto dataClause) {
return dataClause.getAccPtr(); })
4659 [&](
auto dataClause) {
return dataClause.getAccVar(); })
4668 [&](
auto dataClause) {
return dataClause.getVarPtrPtr(); })
4678 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4680 dataClause.getBounds().begin(), dataClause.getBounds().end());
4692 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](
auto dataClause) {
4694 dataClause.getAsyncOperands().begin(),
4695 dataClause.getAsyncOperands().end());
4706 return dataClause.getAsyncOperandsDeviceTypeAttr();
4714 [&](
auto dataClause) {
return dataClause.getAsyncOnlyAttr(); })
4721 .Case<ACC_DATA_ENTRY_OPS>([&](
auto entry) {
return entry.getName(); })
4728std::optional<mlir::acc::DataClause>
4733 .Case<ACC_DATA_ENTRY_OPS>(
4734 [&](
auto entry) {
return entry.getDataClause(); })
4742 [&](
auto entry) {
return entry.getImplicit(); })
4751 [&](
auto entry) {
return entry.getDataClauseOperands(); })
4753 return dataOperands;
4761 [&](
auto entry) {
return entry.getDataClauseOperandsMutable(); })
4763 return dataOperands;
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region ®ion, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
LogicalResult checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
static bool isComputeOperation(Operation *op)
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
static LogicalResult checkVarAndAccVar(Op op)
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
static LogicalResult checkVarAndVarType(Op op)
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
ParseResult parseLoopControl(OpAsmParser &parser, Region ®ion, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
static LogicalResult checkNoModifier(Op op)
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
void printLoopControl(OpAsmPrinter &p, Operation *op, Region ®ion, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
static void printSymOperandList(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > attributes)
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region ®ion, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
static LogicalResult checkSymOperandList(Operation *op, std::optional< mlir::ArrayAttr > attributes, mlir::OperandRange operands, llvm::StringRef operandName, llvm::StringRef symbolName, bool checkOperandType=true)
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
static LogicalResult checkWaitAndAsyncConflict(Op op)
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
static ParseResult parseSymOperandList(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &symbols)
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
#define ACC_DATA_ENTRY_OPS
#define ACC_DATA_EXIT_OPS
false
Parses a map_entries map type from a string format back into its numeric value.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Block represents an ordered list of Operations.
BlockArgument getArgument(unsigned i)
unsigned getNumArguments()
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Operation * getTerminator()
Get the terminator operation of this block.
BlockArgListType getArguments()
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
MLIRContext is the top-level object for a collection of MLIR operations.
This class provides a mutable adaptor for a range of operands.
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region ®ion, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
This class helps build Operations.
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Operation is the basic unit of execution within MLIR.
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
OperandRange operand_range
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
operand_range getOperands()
Returns an iterator on the underlying Value's.
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
iterator_range< OpIterator > getOps()
bool hasOneBlock()
Return true if this region has exactly one block.
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
This class provides an abstraction over the different types of ranges over Values.
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Type getType() const
Return the type of this value.
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
static WalkResult advance()
static WalkResult interrupt()
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
ArrayRef< T > asArrayRef() const
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
static constexpr StringLiteral getVarNameAttrName()
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.