MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Support/LLVM.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/LogicalResult.h"
29#include <variant>
30
31using namespace mlir;
32using namespace acc;
33
34#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
38#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
39
40namespace {
41
42static bool isScalarLikeType(Type type) {
43 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
44}
45
46/// Helper function to attach the `VarName` attribute to an operation
47/// if a variable name is provided.
48static void attachVarNameAttr(Operation *op, OpBuilder &builder,
49 StringRef varName) {
50 if (!varName.empty()) {
51 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
52 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
53 }
54}
55
56template <typename T>
57struct MemRefPointerLikeModel
58 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 Type getElementType(Type pointer) const {
60 return cast<T>(pointer).getElementType();
61 }
62
63 mlir::acc::VariableTypeCategory
64 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
65 Type varType) const {
66 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
67 return mappableTy.getTypeCategory(varPtr);
68 }
69 auto memrefTy = cast<T>(pointer);
70 if (!memrefTy.hasRank()) {
71 // This memref is unranked - aka it could have any rank, including a
72 // rank of 0 which could mean scalar. For now, return uncategorized.
73 return mlir::acc::VariableTypeCategory::uncategorized;
74 }
75
76 if (memrefTy.getRank() == 0) {
77 if (isScalarLikeType(memrefTy.getElementType())) {
78 return mlir::acc::VariableTypeCategory::scalar;
79 }
80 // Zero-rank non-scalar - need further analysis to determine the type
81 // category. For now, return uncategorized.
82 return mlir::acc::VariableTypeCategory::uncategorized;
83 }
84
85 // It has a rank - must be an array.
86 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
87 return mlir::acc::VariableTypeCategory::array;
88 }
89
90 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
91 StringRef varName, Type varType, Value originalVar,
92 bool &needsFree) const {
93 auto memrefTy = cast<MemRefType>(pointer);
94
95 // Check if this is a static memref (all dimensions are known) - if yes
96 // then we can generate an alloca operation.
97 if (memrefTy.hasStaticShape()) {
98 needsFree = false; // alloca doesn't need deallocation
99 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
100 attachVarNameAttr(allocaOp, builder, varName);
101 return allocaOp.getResult();
102 }
103
104 // For dynamic memrefs, extract sizes from the original variable if
105 // provided. Otherwise they cannot be handled.
106 if (originalVar && originalVar.getType() == memrefTy &&
107 memrefTy.hasRank()) {
108 SmallVector<Value> dynamicSizes;
109 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
110 if (memrefTy.isDynamicDim(i)) {
111 // Extract the size of dimension i from the original variable
112 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
113 auto dimSize =
114 memref::DimOp::create(builder, loc, originalVar, indexValue);
115 dynamicSizes.push_back(dimSize);
116 }
117 // Note: We only add dynamic sizes to the dynamicSizes array
118 // Static dimensions are handled automatically by AllocOp
119 }
120 needsFree = true; // alloc needs deallocation
121 auto allocOp =
122 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
123 attachVarNameAttr(allocOp, builder, varName);
124 return allocOp.getResult();
125 }
126
127 // TODO: Unranked not yet supported.
128 return {};
129 }
130
131 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 TypedValue<PointerLikeType> varToFree, Value allocRes,
133 Type varType) const {
134 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
135 // Use allocRes if provided to determine the allocation type
136 Value valueToInspect = allocRes ? allocRes : memrefValue;
137
138 // Walk through casts to find the original allocation
139 Value currentValue = valueToInspect;
140 Operation *originalAlloc = nullptr;
141
142 // Follow the chain of operations to find the original allocation
143 // even if a casted result is provided.
144 while (currentValue) {
145 if (auto *definingOp = currentValue.getDefiningOp()) {
146 // Check if this is an allocation operation
147 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
148 originalAlloc = definingOp;
149 break;
150 }
151
152 // Check if this is a cast operation we can look through
153 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
154 currentValue = castOp.getSource();
155 continue;
156 }
157
158 // Check for other cast-like operations
159 if (auto reinterpretCastOp =
160 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
161 currentValue = reinterpretCastOp.getSource();
162 continue;
163 }
164
165 // If we can't look through this operation, stop
166 break;
167 }
168 // This is a block argument or similar - can't trace further.
169 break;
170 }
171
172 if (originalAlloc) {
173 if (isa<memref::AllocaOp>(originalAlloc)) {
174 // This is an alloca - no dealloc needed, but return true (success)
175 return true;
176 }
177 if (isa<memref::AllocOp>(originalAlloc)) {
178 // This is an alloc - generate dealloc on varToFree
179 memref::DeallocOp::create(builder, loc, memrefValue);
180 return true;
181 }
182 }
183 }
184
185 return false;
186 }
187
188 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
189 TypedValue<PointerLikeType> destination,
190 TypedValue<PointerLikeType> source, Type varType) const {
191 // Generate a copy operation between two memrefs
192 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
193 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
194
195 // As per memref documentation, source and destination must have same
196 // element type and shape in order to be compatible. We do not want to fail
197 // with an IR verification error - thus check that before generating the
198 // copy operation.
199 if (destMemref && srcMemref &&
200 destMemref.getType().getElementType() ==
201 srcMemref.getType().getElementType() &&
202 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
203 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
204 return true;
205 }
206
207 return false;
208 }
209
/// Emits memref.load from `srcPtr` when it is a rank-0 memref, and returns the
/// loaded value. Returns a null Value for non-memref or ranked sources.
/// `valueType` is not consulted by this implementation.
210 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
212 Type valueType) const {
213 // Load from a memref - only valid for scalar memrefs (rank 0).
214 // This is because the address computation for memrefs is part of the load
215 // (and not computed separately), but the API does not have arguments for
216 // indexing.
217 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
218 if (!memrefValue)
219 return {};
220
221 auto memrefTy = memrefValue.getType();
222
223 // Only load from scalar memrefs (rank 0); no indices are available.
224 if (memrefTy.getRank() != 0)
225 return {};
226
227 return memref::LoadOp::create(builder, loc, memrefValue);
228 }
229
230 bool genStore(Type pointer, OpBuilder &builder, Location loc,
231 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
232 // Store to a memref - only valid for scalar memrefs (rank 0)
233 // This is because the address computation for memrefs is part of the store
234 // (and not computed separately), but the API does not have arguments for
235 // indexing.
236 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
237 if (!memrefValue)
238 return false;
239
240 auto memrefTy = memrefValue.getType();
241
242 // Only store to scalar memrefs (rank 0)
243 if (memrefTy.getRank() != 0)
244 return false;
245
246 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
247 return true;
248 }
249
250 Value genCast(Type, OpBuilder &builder, Location loc, Value value,
251 Type resultType) const {
252 if (value.getType() == resultType)
253 return value;
254
255 if (isa<BaseMemRefType>(value.getType()) &&
256 isa<BaseMemRefType>(resultType)) {
257 if (memref::CastOp::areCastCompatible(TypeRange(value.getType()),
258 TypeRange(resultType)))
259 return memref::CastOp::create(builder, loc, resultType, value);
260 if (memref::MemorySpaceCastOp::areCastCompatible(
261 TypeRange(value.getType()), TypeRange(resultType)))
262 return memref::MemorySpaceCastOp::create(builder, loc, resultType,
263 value);
264 }
265
266 // If one side is not a memref, try the other type's `PointerLikeType`
267 // implementation (since it may be an out-of-tree reference type that
268 // we cannot generate here).
269 if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
270 if (!isa<BaseMemRefType>(resPtrLike))
271 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
272 return v;
273 if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
274 if (!isa<BaseMemRefType>(valPtrLike))
275 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
276 return v;
277
278 return {};
279 }
280
281 bool isDeviceData(Type pointer, Value var) const {
282 auto memrefTy = cast<T>(pointer);
283 Attribute memSpace = memrefTy.getMemorySpace();
284 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
285 }
286
287 MemRefType getAsMemRefType(Type pointer, ModuleOp module) const {
288 (void)module;
289 return dyn_cast<MemRefType>(pointer);
290 }
291};
292
293struct LLVMPointerPointerLikeModel
294 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
295 LLVM::LLVMPointerType> {
296 Type getElementType(Type pointer) const { return Type(); }
297
/// Emits llvm.load of `valueType` through `srcPtr`. Opaque LLVM pointers do
/// not say what they point to, so a null Value is returned when `valueType`
/// is not provided.
298 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
300 Type valueType) const {
301 // For LLVM pointers, we need the valueType to determine what to load
302 if (!valueType)
303 return {};
304
305 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
306 }
307
308 bool genStore(Type pointer, OpBuilder &builder, Location loc,
309 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
310 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
311 return true;
312 }
313
314 Value genCast(Type, OpBuilder &builder, Location loc, Value value,
315 Type resultType) const {
316 if (value.getType() == resultType)
317 return value;
318
319 auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(value.getType());
320 auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
321 if (srcPtrTy && dstPtrTy) {
322 if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
323 return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
324 return value;
325 }
326
327 if (srcPtrTy && isa<IntegerType>(resultType))
328 return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
329
330 if (dstPtrTy) {
331 Value intVal = value;
332 if (isa<IndexType>(value.getType()))
333 intVal = arith::IndexCastUIOp::create(builder, loc,
334 builder.getI64Type(), value);
335 if (isa<IntegerType>(intVal.getType()))
336 return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
337 }
338
339 if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
340 if (!isa<LLVM::LLVMPointerType>(resPtrLike))
341 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
342 return v;
343 if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
344 if (!isa<LLVM::LLVMPointerType>(valPtrLike))
345 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
346 return v;
347
348 return UnrealizedConversionCastOp::create(builder, loc,
349 TypeRange(resultType), value)
350 .getResult(0);
351 }
352};
353
354struct PrivateTypePointerLikeModel
355 : public PointerLikeType::ExternalModel<PrivateTypePointerLikeModel,
356 PrivateType> {
357 Type getElementType(Type type) const {
358 return cast<PrivateType>(type).getBaseTy();
359 }
360
361 Value genCast(Type, OpBuilder &builder, Location loc, Value value,
362 Type resultType) const {
363 if (value.getType() == resultType)
364 return value;
365 if (!isa<PointerLikeType>(resultType))
366 return {};
367 return UnwrapPrivateOp::create(builder, loc, resultType, value).getResult();
368 }
369
370 MemRefType getAsMemRefType(Type type, ModuleOp module) const {
371 Type baseTy = cast<PrivateType>(type).getBaseTy();
372 if (auto memrefTy = dyn_cast<MemRefType>(baseTy))
373 return memrefTy;
374 if (auto ptrLikeTy = dyn_cast<PointerLikeType>(baseTy))
375 return ptrLikeTy.getAsMemRefType(module);
376 return {};
377 }
378};
379
380struct MemrefAddressOfGlobalModel
381 : public AddressOfGlobalOpInterface::ExternalModel<
382 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
383 SymbolRefAttr getSymbol(Operation *op) const {
384 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
385 return getGlobalOp.getNameAttr();
386 }
387};
388
389struct MemrefGlobalVariableModel
390 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
391 memref::GlobalOp> {
392 bool isConstant(Operation *op) const {
393 auto globalOp = cast<memref::GlobalOp>(op);
394 return globalOp.getConstant();
395 }
396
397 Region *getInitRegion(Operation *op) const {
398 // GlobalOp uses attributes for initialization, not regions
399 return nullptr;
400 }
401
402 bool isDeviceData(Operation *op) const {
403 auto globalOp = cast<memref::GlobalOp>(op);
404 Attribute memSpace = globalOp.getType().getMemorySpace();
405 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
406 }
407};
408
409struct GPULaunchOffloadRegionModel
410 : public acc::OffloadRegionOpInterface::ExternalModel<
411 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
412 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
413 return cast<gpu::LaunchOp>(op).getBody();
414 }
415};
416
417/// Helper for any of the times we need to modify an ArrayAttr based on a
418/// device type list. Returns a new ArrayAttr with all of the
419/// existingDeviceTypes, plus the effective new ones (or an added `none` if
420/// the new list is empty).
421mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
422 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
423 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
// Keep the device types that were already recorded, in their original order.
425 if (existingDeviceTypes)
426 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
427
// An empty `newDeviceTypes` is recorded as a single DeviceType::None entry.
428 if (newDeviceTypes.empty())
429 deviceTypes.push_back(
430 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
431
432 for (DeviceType dt : newDeviceTypes)
433 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
434
435 return mlir::ArrayAttr::get(context, deviceTypes);
436}
437
438/// Helper for any of the times we need to add operands that are affected by
439/// a device type list. Returns a new ArrayAttr with all of the
440/// existingDeviceTypes, plus the effective new ones (or an added `none`, if
441/// the new list is empty). Additionally, `arguments` are appended to
442/// `argCollection` once per effective new device type, and each copy pushes
443/// its size onto `segments` (updated even if the caller ignores it).
444mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
445 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
446 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
447 mlir::MutableOperandRange argCollection,
448 llvm::SmallVector<int32_t> &segments) {
// Keep the device types that were already recorded, in their original order.
450 if (existingDeviceTypes)
451 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
452
// No explicit device types: attach the arguments once, under DeviceType::None.
453 if (newDeviceTypes.empty()) {
454 argCollection.append(arguments);
455 segments.push_back(arguments.size());
456 deviceTypes.push_back(
457 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
458 }
459
// One full copy of the arguments (and one segment entry) per device type.
460 for (DeviceType dt : newDeviceTypes) {
461 argCollection.append(arguments);
462 segments.push_back(arguments.size());
463 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
464 }
465
466 return mlir::ArrayAttr::get(context, deviceTypes);
467}
468
469/// Overload for callers that do not need the per-device-type 'segments'
/// counts. It forwards to the segment-tracking overload above; the `segments`
/// it passes is presumably a local scratch vector (its declaration is not
/// visible here) — TODO confirm.
470mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
471 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
472 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
473 mlir::MutableOperandRange argCollection) {
475 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
476 newDeviceTypes, arguments,
477 argCollection, segments);
478}
479} // namespace
480
481//===----------------------------------------------------------------------===//
482// OpenACC operations
483//===----------------------------------------------------------------------===//
484
485void OpenACCDialect::initialize() {
486 addOperations<
487#define GET_OP_LIST
488#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
489 >();
490 addAttributes<
491#define GET_ATTRDEF_LIST
492#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
493 >();
494 addTypes<
495#define GET_TYPEDEF_LIST
496#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
497 >();
498
499 // By attaching interfaces here, we make the OpenACC dialect dependent on
500 // the other dialects. This is probably better than having dialects like LLVM
501 // and memref be dependent on OpenACC.
502 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
503 *getContext());
504 UnrankedMemRefType::attachInterface<
505 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
506 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
507 *getContext());
508 PrivateType::attachInterface<PrivateTypePointerLikeModel>(*getContext());
509
510 // Attach operation interfaces
511 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
512 *getContext());
513 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
514 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
515}
516
517//===----------------------------------------------------------------------===//
518// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
519// acc.kernel_environment / acc.data / acc.host_data / acc.loop
520//===----------------------------------------------------------------------===//
521
522/// Generic helper for single-region OpenACC ops that execute their body once
523/// and then continue after the operation with their results (if any).
/// From the parent op, the only successor is the region. From inside the
/// region, the only successor is the parent op (the exit).
524static void
526 RegionBranchPoint point,
528 if (point.isParent()) {
529 regions.push_back(RegionSuccessor(&region));
530 return;
531 }
532
533 regions.push_back(RegionSuccessor(op));
534}
535
537 RegionSuccessor successor) {
// Returning to the parent op forwards the op's results. Entering the region
// passes no values.
538 return successor.isOperation() ? ValueRange(op->getResults()) : ValueRange();
539}
540
/// acc.kernels executes its single region once: parent -> region -> parent.
541void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
543 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
544 regions);
545}
546
547ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
548 return getSingleRegionSuccessorInputs(getOperation(), successor);
549}
550
/// acc.parallel executes its single region once: parent -> region -> parent.
551void ParallelOp::getSuccessorRegions(
553 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
554 regions);
555}
556
557ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
558 return getSingleRegionSuccessorInputs(getOperation(), successor);
559}
560
/// acc.serial executes its single region once: parent -> region -> parent.
561void SerialOp::getSuccessorRegions(RegionBranchPoint point,
563 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
564 regions);
565}
566
567ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
568 return getSingleRegionSuccessorInputs(getOperation(), successor);
569}
570
/// acc.data executes its single region once: parent -> region -> parent.
571void DataOp::getSuccessorRegions(RegionBranchPoint point,
573 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
574 regions);
575}
576
577ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
578 return getSingleRegionSuccessorInputs(getOperation(), successor);
579}
580
/// acc.host_data executes its single region once: parent -> region -> parent.
581void HostDataOp::getSuccessorRegions(
583 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
584 regions);
585}
586
587ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
588 return getSingleRegionSuccessorInputs(getOperation(), successor);
589}
590
/// Region-branch successors of acc.loop; the modeling differs between
/// unstructured and structured loops.
591void LoopOp::getSuccessorRegions(RegionBranchPoint point,
593 // Unstructured loops: the body may contain arbitrary CFG and early exits.
594 // At the RegionBranch level, only model entry into the body and exit to the
595 // parent; any backedges are represented inside the region CFG.
596 if (getUnstructured()) {
597 if (point.isParent()) {
598 regions.push_back(RegionSuccessor(&getRegion()));
599 return;
600 }
601 regions.push_back(RegionSuccessor(getOperation()));
602 return;
603 }
604
605 // Structured loops: model a loop-shaped region graph similar to scf.for.
// From either the parent or the body, control may (re-)enter the body or
// leave to the parent; this captures zero-trip loops and the back edge.
606 regions.push_back(RegionSuccessor(&getRegion()));
607 regions.push_back(RegionSuccessor(getOperation()));
608}
609
610ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
611 return getSingleRegionSuccessorInputs(getOperation(), successor);
612}
613
614//===----------------------------------------------------------------------===//
615// RegionBranchTerminatorOpInterface
616//===----------------------------------------------------------------------===//
617
619TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
620 // `acc.terminator` forwards no operands to any successor, so an empty
// range (start 0, length 0) over this op's operands is returned.
621 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
622}
623
624//===----------------------------------------------------------------------===//
625// device_type support helpers
626//===----------------------------------------------------------------------===//
627
628static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
629 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
630}
631
632static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
633 mlir::acc::DeviceType deviceType) {
634 if (!hasDeviceTypeValues(arrayAttr))
635 return false;
636
637 for (auto attr : *arrayAttr) {
638 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
639 if (deviceTypeAttr.getValue() == deviceType)
640 return true;
641 }
642
643 return false;
644}
645
647 std::optional<mlir::ArrayAttr> deviceTypes) {
// Prints `[attr, attr, ...]`; prints nothing when the list is absent or empty.
648 if (!hasDeviceTypeValues(deviceTypes))
649 return;
650
651 p << "[";
652 llvm::interleaveComma(*deviceTypes, p,
653 [&](mlir::Attribute attr) { p << attr; });
654 p << "]";
655}
656
657static std::optional<unsigned> findSegment(ArrayAttr segments,
658 mlir::acc::DeviceType deviceType) {
659 unsigned segmentIdx = 0;
660 for (auto attr : segments) {
661 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
662 if (deviceTypeAttr.getValue() == deviceType)
663 return std::make_optional(segmentIdx);
664 ++segmentIdx;
665 }
666 return std::nullopt;
667}
668
670getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
672 std::optional<llvm::ArrayRef<int32_t>> segments,
673 mlir::acc::DeviceType deviceType) {
// Returns the slice of `range` belonging to `deviceType`: skip the sizes of
// all earlier segments, then take this segment's size. Returns an empty slice
// when `arrayAttr` is absent or lacks `deviceType`.
// NOTE(review): `segments` is dereferenced without a has_value() check; this
// assumes it is always present whenever `arrayAttr` is — confirm with callers.
674 if (!arrayAttr)
675 return range.take_front(0);
676 if (auto pos = findSegment(*arrayAttr, deviceType)) {
677 int32_t nbOperandsBefore = 0;
678 for (unsigned i = 0; i < *pos; ++i)
679 nbOperandsBefore += (*segments)[i];
680 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
681 }
682 return range.take_front(0);
683}
684
/// Returns the wait `devnum` value for `deviceType`: the first operand of that
/// device type's wait segment, when the matching `hasWaitDevnum` entry is a
/// true BoolAttr. Otherwise returns a null Value.
/// NOTE(review): `.front()` assumes the segment is non-empty whenever
/// hasWaitDevnum is true — presumably enforced by the verifier; confirm.
685static mlir::Value
686getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
688 std::optional<llvm::ArrayRef<int32_t>> segments,
689 std::optional<mlir::ArrayAttr> hasWaitDevnum,
690 mlir::acc::DeviceType deviceType) {
691 if (!hasDeviceTypeValues(deviceTypeAttr))
692 return {};
693 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
694 if (hasWaitDevnum && *hasWaitDevnum) {
695 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
696 if (boolAttr && boolAttr.getValue())
697 return getValuesFromSegments(deviceTypeAttr, operands, segments,
698 deviceType)
699 .front();
700 }
701 }
702 return {};
703}
704
706getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
708 std::optional<llvm::ArrayRef<int32_t>> segments,
709 std::optional<mlir::ArrayAttr> hasWaitDevnum,
710 mlir::acc::DeviceType deviceType) {
// Returns the wait operands of `deviceType`'s segment, dropping the leading
// devnum operand when `hasWaitDevnum` marks one as present.
711 auto range =
712 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
713 if (range.empty())
714 return range;
715 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
716 if (hasWaitDevnum && *hasWaitDevnum) {
717 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
// NOTE(review): unlike getWaitDevnumValue, `boolAttr` is not null-checked
// before use; a non-BoolAttr entry would be dereferenced as null.
718 if (boolAttr.getValue())
719 return range.drop_front(1); // first value is devnum
720 }
721 }
722 return range;
723}
724
725template <typename Op>
726static LogicalResult checkWaitAndAsyncConflict(Op op) {
727 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
728 ++dtypeInt) {
729 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
730
731 // The asyncOnly attribute represent the async clause without value.
732 // Therefore the attribute and operand cannot appear at the same time.
733 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
734 op.hasAsyncOnly(dtype))
735 return op.emitError(
736 "asyncOnly attribute cannot appear with asyncOperand");
737
738 // The wait attribute represent the wait clause without values. Therefore
739 // the attribute and operands cannot appear at the same time.
740 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
741 op.hasWaitOnly(dtype))
742 return op.emitError("wait attribute cannot appear with waitOperands");
743 }
744 return success();
745}
746
747template <typename Op>
748static LogicalResult checkVarAndVarType(Op op) {
749 if (!op.getVar())
750 return op.emitError("must have var operand");
751
752 // A variable must have a type that is either pointer-like or mappable.
753 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
754 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
755 return op.emitError("var must be mappable or pointer-like");
756
757 // When it is a pointer-like type, the varType must capture the target type.
758 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
759 op.getVarType() == op.getVar().getType())
760 return op.emitError("varType must capture the element type of var");
761
762 return success();
763}
764
765template <typename Op>
766static LogicalResult checkVarAndAccVar(Op op) {
767 if (op.getVar().getType() != op.getAccVar().getType())
768 return op.emitError("input and output types must match");
769
770 return success();
771}
772
773template <typename Op>
774static LogicalResult checkNoModifier(Op op) {
775 if (op.getModifiers() != acc::DataClauseModifier::none)
776 return op.emitError("no data clause modifiers are allowed");
777 return success();
778}
779
780template <typename Op>
781static LogicalResult
782checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
783 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
784 return op.emitError(
785 "invalid data clause modifiers: " +
786 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
787
788 return success();
789}
790
/// Verifies that `op` names a recipe of kind `RecipeOpT` via its recipe
/// attribute. Mappable vars are exempt (except for reductions). `operandName`
/// is used only in the diagnostics.
791template <typename OpT, typename RecipeOpT>
792static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
793 // Mappable types do not need a recipe because it is possible to generate one
794 // from its API. Reject reductions though because no API is available for them
795 // at this time.
796 if (mlir::acc::isMappableType(op.getVar().getType()) &&
797 !std::is_same_v<OpT, acc::ReductionOp>)
798 return success();
799
800 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
801 if (!operandRecipe)
802 return op->emitOpError() << "recipe expected for " << operandName;
803
// The symbol must resolve to an op of the expected recipe kind.
804 auto decl =
806 if (!decl)
807 return op->emitOpError()
808 << "expected symbol reference " << operandRecipe << " to point to a "
809 << operandName << " declaration";
810 return success();
811}
812
/// Parses `varPtr(<operand>` or `var(<operand>`. The closing `)` is not
/// consumed here; parseVarPtrType consumes a type followed by `)`, presumably
/// completing this group in the assembly format — TODO confirm.
813static ParseResult parseVar(mlir::OpAsmParser &parser,
815 // Either `var` or `varPtr` keyword is required.
816 if (failed(parser.parseOptionalKeyword("varPtr"))) {
817 if (failed(parser.parseKeyword("var")))
818 return failure();
819 }
820 if (failed(parser.parseLParen()))
821 return failure();
822 if (failed(parser.parseOperand(var)))
823 return failure();
824
825 return success();
826}
827
829 mlir::Value var) {
// Counterpart of parseVar: `varPtr(` for pointer-like vars, `var(` otherwise,
// followed by the operand. The closing `)` is left to a later printer.
830 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
831 p << "varPtr(";
832 else
833 p << "var(";
834 p.printOperand(var);
835}
836
/// Parses `accPtr(<operand> : <type>)` or `accVar(<operand> : <type>)`,
/// storing the operand in `var` and the type in `accVarType`.
837static ParseResult parseAccVar(mlir::OpAsmParser &parser,
839 mlir::Type &accVarType) {
840 // Either `accVar` or `accPtr` keyword is required.
841 if (failed(parser.parseOptionalKeyword("accPtr"))) {
842 if (failed(parser.parseKeyword("accVar")))
843 return failure();
844 }
845 if (failed(parser.parseLParen()))
846 return failure();
847 if (failed(parser.parseOperand(var)))
848 return failure();
849 if (failed(parser.parseColon()))
850 return failure();
851 if (failed(parser.parseType(accVarType)))
852 return failure();
853 if (failed(parser.parseRParen()))
854 return failure();
855
856 return success();
857}
858
860 mlir::Value accVar, mlir::Type accVarType) {
// Counterpart of parseAccVar: `accPtr(` for pointer-like values, `accVar(`
// otherwise, then `<operand> : <type>)`.
861 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
862 p << "accPtr(";
863 else
864 p << "accVar(";
865 p.printOperand(accVar);
866 p << " : ";
867 p.printType(accVarType);
868 p << ")";
869}
870
871static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
872 mlir::Type &varPtrType,
873 mlir::TypeAttr &varTypeAttr) {
874 if (failed(parser.parseType(varPtrType)))
875 return failure();
876 if (failed(parser.parseRParen()))
877 return failure();
878
879 if (succeeded(parser.parseOptionalKeyword("varType"))) {
880 if (failed(parser.parseLParen()))
881 return failure();
882 mlir::Type varType;
883 if (failed(parser.parseType(varType)))
884 return failure();
885 varTypeAttr = mlir::TypeAttr::get(varType);
886 if (failed(parser.parseRParen()))
887 return failure();
888 } else {
889 // Set `varType` from the element type of the type of `varPtr`.
890 if (auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
891 Type elementType = ptrTy.getElementType();
892 // Opaque pointers (e.g. !llvm.ptr) have no element type; fall back to
893 // using varPtrType itself so that the attribute is always valid.
894 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
895 } else {
896 varTypeAttr = mlir::TypeAttr::get(varPtrType);
897 }
898 }
899
900 return success();
901}
902
904 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
// Counterpart of parseVarPtrType: prints `<type>)` and, only when it cannot
// be re-inferred by the parser, ` varType(<type>)`.
905 p.printType(varPtrType);
906 p << ")";
907
908 // Print the `varType` only if it differs from the element type of
909 // `varPtr`'s type.
910 mlir::Type varType = varTypeAttr.getValue();
911 mlir::Type typeToCheckAgainst =
912 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
913 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
914 : varPtrType;
915 // Opaque pointers (e.g. !llvm.ptr) have no element type; use varPtrType as
916 // the baseline so that the inferred varType is not redundantly printed.
917 if (!typeToCheckAgainst)
918 typeToCheckAgainst = varPtrType;
919 if (typeToCheckAgainst != varType) {
920 p << " varType(";
921 p.printType(varType);
922 p << ")";
923 }
924}
925
926static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
927 mlir::SymbolRefAttr &recipeAttr) {
928 if (failed(parser.parseAttribute(recipeAttr)))
929 return failure();
930 return success();
931}
932
934 mlir::SymbolRefAttr recipeAttr) {
 // Prints the recipe symbol reference; counterpart of parseRecipeSym.
935 p << recipeAttr;
936}
937
938//===----------------------------------------------------------------------===//
939// DataBoundsOp
940//===----------------------------------------------------------------------===//
941LogicalResult acc::DataBoundsOp::verify() {
942 auto extent = getExtent();
943 auto upperbound = getUpperbound();
944 if (!extent && !upperbound)
945 return emitError("expected extent or upperbound.");
946 return success();
947}
948
949//===----------------------------------------------------------------------===//
950// PrivateOp
951//===----------------------------------------------------------------------===//
/// Verifies an acc.private operation: its data clause must be acc_private,
/// checkVarAndVarType and checkNoModifier must succeed, and a final check
/// (see note below) must succeed.
952LogicalResult acc::PrivateOp::verify() {
953 if (getDataClause() != acc::DataClause::acc_private)
954 return emitError(
955 "data clause associated with private operation must match its intent");
956 if (failed(checkVarAndVarType(*this)))
957 return failure();
958 if (failed(checkNoModifier(*this)))
959 return failure();
 // NOTE(review): the argument of this `failed(` is not visible in this
 // listing; presumably it validates the referenced private recipe — confirm.
960 if (failed(
962 return failure();
963 return success();
964}
965
966//===----------------------------------------------------------------------===//
967// FirstprivateOp
968//===----------------------------------------------------------------------===//
/// Verifies an acc.firstprivate operation: its data clause must be
/// acc_firstprivate, checkVarAndVarType and checkNoModifier must succeed, and
/// a final check invoked with the "firstprivate" label must succeed.
969LogicalResult acc::FirstprivateOp::verify() {
970 if (getDataClause() != acc::DataClause::acc_firstprivate)
971 return emitError("data clause associated with firstprivate operation must "
972 "match its intent");
973 if (failed(checkVarAndVarType(*this)))
974 return failure();
975 if (failed(checkNoModifier(*this)))
976 return failure();
 // NOTE(review): the head of the next check is not visible in this listing;
 // presumably it validates the referenced firstprivate recipe — confirm.
978 *this, "firstprivate")))
979 return failure();
980 return success();
981}
982
983//===----------------------------------------------------------------------===//
984// ReductionOp
985//===----------------------------------------------------------------------===//
/// Verifies an acc.reduction operation: its data clause must be
/// acc_reduction, checkVarAndVarType and checkNoModifier must succeed, and a
/// final check invoked with the "reduction" label must succeed.
986LogicalResult acc::ReductionOp::verify() {
987 if (getDataClause() != acc::DataClause::acc_reduction)
988 return emitError("data clause associated with reduction operation must "
989 "match its intent");
990 if (failed(checkVarAndVarType(*this)))
991 return failure();
992 if (failed(checkNoModifier(*this)))
993 return failure();
 // NOTE(review): the head of the next check is not visible in this listing;
 // presumably it validates the referenced reduction recipe — confirm.
995 *this, "reduction")))
996 return failure();
997 return success();
998}
999
1000//===----------------------------------------------------------------------===//
1001// DevicePtrOp
1002//===----------------------------------------------------------------------===//
1003LogicalResult acc::DevicePtrOp::verify() {
1004 if (getDataClause() != acc::DataClause::acc_deviceptr)
1005 return emitError("data clause associated with deviceptr operation must "
1006 "match its intent");
1007 if (failed(checkVarAndVarType(*this)))
1008 return failure();
1009 if (failed(checkVarAndAccVar(*this)))
1010 return failure();
1011 if (failed(checkNoModifier(*this)))
1012 return failure();
1013 return success();
1014}
1015
1016//===----------------------------------------------------------------------===//
1017// PresentOp
1018//===----------------------------------------------------------------------===//
1019LogicalResult acc::PresentOp::verify() {
1020 if (getDataClause() != acc::DataClause::acc_present)
1021 return emitError(
1022 "data clause associated with present operation must match its intent");
1023 if (failed(checkVarAndVarType(*this)))
1024 return failure();
1025 if (failed(checkVarAndAccVar(*this)))
1026 return failure();
1027 if (failed(checkNoModifier(*this)))
1028 return failure();
1029 return success();
1030}
1031
1032//===----------------------------------------------------------------------===//
1033// CopyinOp
1034//===----------------------------------------------------------------------===//
1035LogicalResult acc::CopyinOp::verify() {
1036 // Test for all clauses this operation can be decomposed from:
1037 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
1038 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1039 getDataClause() != acc::DataClause::acc_copy &&
1040 getDataClause() != acc::DataClause::acc_reduction)
1041 return emitError(
1042 "data clause associated with copyin operation must match its intent"
1043 " or specify original clause this operation was decomposed from");
1044 if (failed(checkVarAndVarType(*this)))
1045 return failure();
1046 if (failed(checkVarAndAccVar(*this)))
1047 return failure();
1048 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
1049 acc::DataClauseModifier::always |
1050 acc::DataClauseModifier::capture)))
1051 return failure();
1052 return success();
1053}
1054
1055bool acc::CopyinOp::isCopyinReadonly() {
1056 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
1057 acc::bitEnumContainsAny(getModifiers(),
1058 acc::DataClauseModifier::readonly);
1059}
1060
1061//===----------------------------------------------------------------------===//
1062// CreateOp
1063//===----------------------------------------------------------------------===//
1064LogicalResult acc::CreateOp::verify() {
1065 // Test for all clauses this operation can be decomposed from:
1066 if (getDataClause() != acc::DataClause::acc_create &&
1067 getDataClause() != acc::DataClause::acc_create_zero &&
1068 getDataClause() != acc::DataClause::acc_copyout &&
1069 getDataClause() != acc::DataClause::acc_copyout_zero)
1070 return emitError(
1071 "data clause associated with create operation must match its intent"
1072 " or specify original clause this operation was decomposed from");
1073 if (failed(checkVarAndVarType(*this)))
1074 return failure();
1075 if (failed(checkVarAndAccVar(*this)))
1076 return failure();
1077 // this op is the entry part of copyout, so it also needs to allow all
1078 // modifiers allowed on copyout.
1079 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1080 acc::DataClauseModifier::always |
1081 acc::DataClauseModifier::capture)))
1082 return failure();
1083 return success();
1084}
1085
1086bool acc::CreateOp::isCreateZero() {
1087 // The zero modifier is encoded in the data clause.
1088 return getDataClause() == acc::DataClause::acc_create_zero ||
1089 getDataClause() == acc::DataClause::acc_copyout_zero ||
1090 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1091}
1092
1093//===----------------------------------------------------------------------===//
1094// NoCreateOp
1095//===----------------------------------------------------------------------===//
1096LogicalResult acc::NoCreateOp::verify() {
1097 if (getDataClause() != acc::DataClause::acc_no_create)
1098 return emitError("data clause associated with no_create operation must "
1099 "match its intent");
1100 if (failed(checkVarAndVarType(*this)))
1101 return failure();
1102 if (failed(checkVarAndAccVar(*this)))
1103 return failure();
1104 if (failed(checkNoModifier(*this)))
1105 return failure();
1106 return success();
1107}
1108
1109//===----------------------------------------------------------------------===//
1110// AttachOp
1111//===----------------------------------------------------------------------===//
1112LogicalResult acc::AttachOp::verify() {
1113 if (getDataClause() != acc::DataClause::acc_attach)
1114 return emitError(
1115 "data clause associated with attach operation must match its intent");
1116 if (failed(checkVarAndVarType(*this)))
1117 return failure();
1118 if (failed(checkVarAndAccVar(*this)))
1119 return failure();
1120 if (failed(checkNoModifier(*this)))
1121 return failure();
1122 return success();
1123}
1124
1125//===----------------------------------------------------------------------===//
1126// DeclareDeviceResidentOp
1127//===----------------------------------------------------------------------===//
1128
1129LogicalResult acc::DeclareDeviceResidentOp::verify() {
1130 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1131 return emitError("data clause associated with device_resident operation "
1132 "must match its intent");
1133 if (failed(checkVarAndVarType(*this)))
1134 return failure();
1135 if (failed(checkVarAndAccVar(*this)))
1136 return failure();
1137 if (failed(checkNoModifier(*this)))
1138 return failure();
1139 return success();
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// DeclareLinkOp
1144//===----------------------------------------------------------------------===//
1145
1146LogicalResult acc::DeclareLinkOp::verify() {
1147 if (getDataClause() != acc::DataClause::acc_declare_link)
1148 return emitError(
1149 "data clause associated with link operation must match its intent");
1150 if (failed(checkVarAndVarType(*this)))
1151 return failure();
1152 if (failed(checkVarAndAccVar(*this)))
1153 return failure();
1154 if (failed(checkNoModifier(*this)))
1155 return failure();
1156 return success();
1157}
1158
1159//===----------------------------------------------------------------------===//
1160// CopyoutOp
1161//===----------------------------------------------------------------------===//
1162LogicalResult acc::CopyoutOp::verify() {
1163 // Test for all clauses this operation can be decomposed from:
1164 if (getDataClause() != acc::DataClause::acc_copyout &&
1165 getDataClause() != acc::DataClause::acc_copyout_zero &&
1166 getDataClause() != acc::DataClause::acc_copy &&
1167 getDataClause() != acc::DataClause::acc_reduction)
1168 return emitError(
1169 "data clause associated with copyout operation must match its intent"
1170 " or specify original clause this operation was decomposed from");
1171 if (!getVar() || !getAccVar())
1172 return emitError("must have both host and device pointers");
1173 if (failed(checkVarAndVarType(*this)))
1174 return failure();
1175 if (failed(checkVarAndAccVar(*this)))
1176 return failure();
1177 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1178 acc::DataClauseModifier::always |
1179 acc::DataClauseModifier::capture)))
1180 return failure();
1181 return success();
1182}
1183
1184bool acc::CopyoutOp::isCopyoutZero() {
1185 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1186 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1187}
1188
1189//===----------------------------------------------------------------------===//
1190// DeleteOp
1191//===----------------------------------------------------------------------===//
1192LogicalResult acc::DeleteOp::verify() {
1193 // Test for all clauses this operation can be decomposed from:
1194 if (getDataClause() != acc::DataClause::acc_delete &&
1195 getDataClause() != acc::DataClause::acc_create &&
1196 getDataClause() != acc::DataClause::acc_create_zero &&
1197 getDataClause() != acc::DataClause::acc_copyin &&
1198 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1199 getDataClause() != acc::DataClause::acc_present &&
1200 getDataClause() != acc::DataClause::acc_no_create &&
1201 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1202 getDataClause() != acc::DataClause::acc_declare_link)
1203 return emitError(
1204 "data clause associated with delete operation must match its intent"
1205 " or specify original clause this operation was decomposed from");
1206 if (!getAccVar())
1207 return emitError("must have device pointer");
1208 // This op is the exit part of copyin and create - thus allow all modifiers
1209 // allowed on either case.
1210 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1211 acc::DataClauseModifier::readonly |
1212 acc::DataClauseModifier::always |
1213 acc::DataClauseModifier::capture)))
1214 return failure();
1215 return success();
1216}
1217
1218//===----------------------------------------------------------------------===//
1219// DetachOp
1220//===----------------------------------------------------------------------===//
1221LogicalResult acc::DetachOp::verify() {
1222 // Test for all clauses this operation can be decomposed from:
1223 if (getDataClause() != acc::DataClause::acc_detach &&
1224 getDataClause() != acc::DataClause::acc_attach)
1225 return emitError(
1226 "data clause associated with detach operation must match its intent"
1227 " or specify original clause this operation was decomposed from");
1228 if (!getAccVar())
1229 return emitError("must have device pointer");
1230 if (failed(checkNoModifier(*this)))
1231 return failure();
1232 return success();
1233}
1234
1235//===----------------------------------------------------------------------===//
1236// HostOp
1237//===----------------------------------------------------------------------===//
1238LogicalResult acc::UpdateHostOp::verify() {
1239 // Test for all clauses this operation can be decomposed from:
1240 if (getDataClause() != acc::DataClause::acc_update_host &&
1241 getDataClause() != acc::DataClause::acc_update_self)
1242 return emitError(
1243 "data clause associated with host operation must match its intent"
1244 " or specify original clause this operation was decomposed from");
1245 if (!getVar() || !getAccVar())
1246 return emitError("must have both host and device pointers");
1247 if (failed(checkVarAndVarType(*this)))
1248 return failure();
1249 if (failed(checkVarAndAccVar(*this)))
1250 return failure();
1251 if (failed(checkNoModifier(*this)))
1252 return failure();
1253 return success();
1254}
1255
1256//===----------------------------------------------------------------------===//
1257// DeviceOp
1258//===----------------------------------------------------------------------===//
1259LogicalResult acc::UpdateDeviceOp::verify() {
1260 // Test for all clauses this operation can be decomposed from:
1261 if (getDataClause() != acc::DataClause::acc_update_device)
1262 return emitError(
1263 "data clause associated with device operation must match its intent"
1264 " or specify original clause this operation was decomposed from");
1265 if (failed(checkVarAndVarType(*this)))
1266 return failure();
1267 if (failed(checkVarAndAccVar(*this)))
1268 return failure();
1269 if (failed(checkNoModifier(*this)))
1270 return failure();
1271 return success();
1272}
1273
1274//===----------------------------------------------------------------------===//
1275// UseDeviceOp
1276//===----------------------------------------------------------------------===//
1277LogicalResult acc::UseDeviceOp::verify() {
1278 // Test for all clauses this operation can be decomposed from:
1279 if (getDataClause() != acc::DataClause::acc_use_device)
1280 return emitError(
1281 "data clause associated with use_device operation must match its intent"
1282 " or specify original clause this operation was decomposed from");
1283 if (failed(checkVarAndVarType(*this)))
1284 return failure();
1285 if (failed(checkVarAndAccVar(*this)))
1286 return failure();
1287 if (failed(checkNoModifier(*this)))
1288 return failure();
1289 return success();
1290}
1291
1292//===----------------------------------------------------------------------===//
1293// CacheOp
1294//===----------------------------------------------------------------------===//
1295LogicalResult acc::CacheOp::verify() {
1296 // Test for all clauses this operation can be decomposed from:
1297 if (getDataClause() != acc::DataClause::acc_cache &&
1298 getDataClause() != acc::DataClause::acc_cache_readonly)
1299 return emitError(
1300 "data clause associated with cache operation must match its intent"
1301 " or specify original clause this operation was decomposed from");
1302 if (failed(checkVarAndVarType(*this)))
1303 return failure();
1304 if (failed(checkVarAndAccVar(*this)))
1305 return failure();
1306 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1307 return failure();
1308 return success();
1309}
1310
1311bool acc::CacheOp::isCacheReadonly() {
1312 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1313 acc::bitEnumContainsAny(getModifiers(),
1314 acc::DataClauseModifier::readonly);
1315}
1316
1317//===----------------------------------------------------------------------===//
1318// Data entry/exit operations - getEffects implementations
1319//===----------------------------------------------------------------------===//
1320
1321// This function returns true iff the given operation is enclosed
1322// in any ACC_COMPUTE_CONSTRUCT_OPS operation.
1323// It is similar to the acc::getEnclosingComputeOp() utility,
1324// but we cannot use it here.
1328
1329/// Helper that appends one `EffectTy` effect per operand in the given mutable
/// operand range. Each entry is attributed to its individual OpOperand; an
/// empty range appends nothing.
1330template <typename EffectTy>
1333 &effects,
1334 MutableOperandRange operand) {
1335 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1336 effects.emplace_back(EffectTy::get(), &operand[i]);
1337}
1338
1339/// Helper to add an effect on a result value.
/// The value must be an OpResult; `mlir::cast` asserts otherwise.
1340template <typename EffectTy>
1343 &effects,
1344 Value result) {
1345 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1346}
1347
1348// PrivateOp: accVar result write.
// When the op is not enclosed in a compute construct, a Read effect is also
// recorded; per the in-body comment it targets the CurrentDeviceIdResource.
// NOTE(review): that resource argument and the accVar write are on lines not
// visible in this listing — confirm.
1349void acc::PrivateOp::getEffects(
1351 &effects) {
1352 // If acc.private is enclosed into a compute operation,
1353 // then it denotes the device side privatization, hence
1354 // it does not access the CurrentDeviceIdResource.
1355 if (!isEnclosedIntoComputeOp(getOperation()))
1356 effects.emplace_back(MemoryEffects::Read::get(),
1358 // TODO: should this be MemoryEffects::Allocate?
1360}
1361
1362// FirstprivateOp: var read, accVar result write.
// Outside compute constructs a Read of the current device id is also
// recorded. `var` is read because the firstprivate copy is initialized from
// the original. NOTE(review): some argument lines (device id resource, accVar
// write) are not visible in this listing — confirm.
1363void acc::FirstprivateOp::getEffects(
1365 &effects) {
1366 // If acc.firstprivate is enclosed into a compute operation,
1367 // then it denotes the device side privatization, hence
1368 // it does not access the CurrentDeviceIdResource.
1369 if (!isEnclosedIntoComputeOp(getOperation()))
1370 effects.emplace_back(MemoryEffects::Read::get(),
1372 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1374}
1375
1376// ReductionOp: var read, accVar result write.
// Outside compute constructs a Read of the current device id is also
// recorded. NOTE(review): some argument lines (device id resource, accVar
// write) are not visible in this listing — confirm.
1377void acc::ReductionOp::getEffects(
1379 &effects) {
1380 // If acc.reduction is enclosed into a compute operation,
1381 // then it denotes the device side reduction, hence
1382 // it does not access the CurrentDeviceIdResource.
1383 if (!isEnclosedIntoComputeOp(getOperation()))
1384 effects.emplace_back(MemoryEffects::Read::get(),
1386 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1388}
1389
1390// DevicePtrOp: RuntimeCounters read.
// A second Read is appended whose resource argument is not visible in this
// listing (presumably the current device id resource — confirm).
1391void acc::DevicePtrOp::getEffects(
1393 &effects) {
1394 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1395 effects.emplace_back(MemoryEffects::Read::get(),
1397}
1398
1399// PresentOp: RuntimeCounters read+write.
// Effects in order: Read(RuntimeCounters), then a Write and a second Read
// whose resource arguments are not visible in this listing (presumably
// RuntimeCounters and the current device id resource — confirm).
1400void acc::PresentOp::getEffects(
1402 &effects) {
1403 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1404 effects.emplace_back(MemoryEffects::Write::get(),
1406 effects.emplace_back(MemoryEffects::Read::get(),
1408}
1409
1410// CopyinOp: RuntimeCounters read+write, var read, accVar result write.
// NOTE(review): the Write/second Read resource arguments and the accVar
// result effect are on lines not visible in this listing — confirm.
1411void acc::CopyinOp::getEffects(
1413 &effects) {
1414 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1415 effects.emplace_back(MemoryEffects::Write::get(),
1417 effects.emplace_back(MemoryEffects::Read::get(),
1419 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1421}
1422
1423// CreateOp: RuntimeCounters read+write, accVar result write.
// NOTE(review): the Write/second Read resource arguments and the accVar
// result effect are on lines not visible in this listing — confirm.
1424void acc::CreateOp::getEffects(
1426 &effects) {
1427 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1428 effects.emplace_back(MemoryEffects::Write::get(),
1430 effects.emplace_back(MemoryEffects::Read::get(),
1432 // TODO: should this be MemoryEffects::Allocate?
1434}
1435
1436// NoCreateOp: RuntimeCounters read+write.
// Effects in order: Read(RuntimeCounters), then a Write and a second Read
// whose resource arguments are not visible in this listing (presumably
// RuntimeCounters and the current device id resource — confirm).
1437void acc::NoCreateOp::getEffects(
1439 &effects) {
1440 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1441 effects.emplace_back(MemoryEffects::Write::get(),
1443 effects.emplace_back(MemoryEffects::Read::get(),
1445}
1446
1447// AttachOp: RuntimeCounters read+write, var read.
// NOTE(review): the Write/second Read resource arguments are on lines not
// visible in this listing — confirm.
1448void acc::AttachOp::getEffects(
1450 &effects) {
1451 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1452 effects.emplace_back(MemoryEffects::Write::get(),
1454 effects.emplace_back(MemoryEffects::Read::get(),
1456 // TODO: should we also add MemoryEffects::Write?
1457 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1458}
1459
1460// GetDevicePtrOp: RuntimeCounters read.
// A second Read is appended whose resource argument is not visible in this
// listing (presumably the current device id resource — confirm).
1461void acc::GetDevicePtrOp::getEffects(
1463 &effects) {
1464 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1465 effects.emplace_back(MemoryEffects::Read::get(),
1467}
1468
1469// UpdateDeviceOp: var read, accVar result write.
// NOTE(review): the leading Read's resource argument and the accVar result
// effect are on lines not visible in this listing — confirm.
1470void acc::UpdateDeviceOp::getEffects(
1472 &effects) {
1473 effects.emplace_back(MemoryEffects::Read::get(),
1475 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1477}
1478
1479// UseDeviceOp: RuntimeCounters read.
// A second Read is appended whose resource argument is not visible in this
// listing (presumably the current device id resource — confirm).
1480void acc::UseDeviceOp::getEffects(
1482 &effects) {
1483 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1484 effects.emplace_back(MemoryEffects::Read::get(),
1486}
1487
1488// DeclareDeviceResidentOp: RuntimeCounters write, var read.
// NOTE(review): the Write and Read resource arguments are on lines not
// visible in this listing — confirm.
1489void acc::DeclareDeviceResidentOp::getEffects(
1491 &effects) {
1492 effects.emplace_back(MemoryEffects::Write::get(),
1494 effects.emplace_back(MemoryEffects::Read::get(),
1496 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1497}
1498
1499// DeclareLinkOp: RuntimeCounters write, var read.
// NOTE(review): the Write and Read resource arguments are on lines not
// visible in this listing — confirm.
1500void acc::DeclareLinkOp::getEffects(
1502 &effects) {
1503 effects.emplace_back(MemoryEffects::Write::get(),
1505 effects.emplace_back(MemoryEffects::Read::get(),
1507 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1508}
1509
1510// CacheOp: NoMemoryEffect
// The body intentionally appends no effects.
1511void acc::CacheOp::getEffects(
1513 &effects) {}
1514
1515// CopyoutOp: RuntimeCounters read+write, accVar read, var write.
// NOTE(review): the Write/second Read resource arguments are on lines not
// visible in this listing — confirm.
1516void acc::CopyoutOp::getEffects(
1518 &effects) {
1519 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1520 effects.emplace_back(MemoryEffects::Write::get(),
1522 effects.emplace_back(MemoryEffects::Read::get(),
1524 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1525 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1526}
1527
1528// DeleteOp: RuntimeCounters read+write, accVar read.
// NOTE(review): the Write/second Read resource arguments are on lines not
// visible in this listing — confirm.
1529void acc::DeleteOp::getEffects(
1531 &effects) {
1532 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1533 effects.emplace_back(MemoryEffects::Write::get(),
1535 effects.emplace_back(MemoryEffects::Read::get(),
1537 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1538}
1539
1540// DetachOp: RuntimeCounters read+write, accVar read.
// NOTE(review): the Write/second Read resource arguments are on lines not
// visible in this listing — confirm.
1541void acc::DetachOp::getEffects(
1543 &effects) {
1544 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1545 effects.emplace_back(MemoryEffects::Write::get(),
1547 effects.emplace_back(MemoryEffects::Read::get(),
1549 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1550}
1551
1552// UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
// NOTE(review): the Write/second Read resource arguments are on lines not
// visible in this listing — confirm.
1553void acc::UpdateHostOp::getEffects(
1555 &effects) {
1556 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1557 effects.emplace_back(MemoryEffects::Write::get(),
1559 effects.emplace_back(MemoryEffects::Read::get(),
1561 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1562 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1563}
1564
/// Adds `nRegions` regions to `state` and parses them in order. No
/// caller-provided block arguments or argument types are passed. Returns
/// failure as soon as any region fails to parse. `StructureOp` is not
/// referenced in the visible body.
/// NOTE(review): the declaration of `regions` is not visible in this listing;
/// presumably a small vector of `Region *` — confirm.
1565template <typename StructureOp>
1566static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1567 unsigned nRegions = 1) {
1568
1570 for (unsigned i = 0; i < nRegions; ++i)
1571 regions.push_back(state.addRegion());
1572
1573 for (Region *region : regions)
1574 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1575 return failure();
1576
1577 return success();
1578}
1579
1580namespace {
1581/// Pattern to remove operation without region that have constant false `ifCond`
1582/// and remove the condition from the operation if the `ifCond` is a true
1583/// constant.
1584template <typename OpTy>
1585struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1586 using OpRewritePattern<OpTy>::OpRewritePattern;
1587
1588 LogicalResult matchAndRewrite(OpTy op,
1589 PatternRewriter &rewriter) const override {
1590 // Early return if there is no condition.
1591 Value ifCond = op.getIfCond();
1592 if (!ifCond)
1593 return failure();
1594
1595 IntegerAttr constAttr;
1596 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1597 return failure();
1598 if (constAttr.getInt())
1599 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1600 else
1601 rewriter.eraseOp(op);
1602
1603 return success();
1604 }
1605};
1606
1607/// Replaces the given op with the contents of the given single-block region,
1608/// using the operands of the block terminator to replace operation results.
1609static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1610 Region &region, ValueRange blockArgs = {}) {
1611 assert(region.hasOneBlock() && "expected single-block region");
1612 Block *block = &region.front();
1613 Operation *terminator = block->getTerminator();
1614 ValueRange results = terminator->getOperands();
1615 rewriter.inlineBlockBefore(block, op, blockArgs);
1616 rewriter.replaceOp(op, results);
1617 rewriter.eraseOp(terminator);
1618}
1619
1620/// Pattern to remove operation with region that have constant false `ifCond`
1621/// and remove the condition from the operation if the `ifCond` is constant
1622/// true.
1623template <typename OpTy>
1624struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1625 using OpRewritePattern<OpTy>::OpRewritePattern;
1626
1627 LogicalResult matchAndRewrite(OpTy op,
1628 PatternRewriter &rewriter) const override {
1629 // Early return if there is no condition.
1630 Value ifCond = op.getIfCond();
1631 if (!ifCond)
1632 return failure();
1633
1634 IntegerAttr constAttr;
1635 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1636 return failure();
1637 if (constAttr.getInt())
1638 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1639 else
1640 replaceOpWithRegion(rewriter, op, op.getRegion());
1641
1642 return success();
1643 }
1644};
1645
1646//===----------------------------------------------------------------------===//
1647// Recipe Region Helpers
1648//===----------------------------------------------------------------------===//
1649
1650/// Create and populate an init region for privatization recipes.
1651/// Returns success if the region is populated, failure otherwise.
1652/// Sets needsFree to indicate if the allocated memory requires deallocation.
1653/// The `hostVar` is the original host variable used to derive
1654/// language-specific metadata via `genPrivateVariableInfo`.
1655/// The `varInfo` output parameter is set to the variable info produced.
1656static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1657 Region &initRegion, Value hostVar,
1658 StringRef varName, ValueRange bounds,
1659 bool &needsFree,
1660 acc::VariableInfoAttr &varInfo) {
1661 Type varType = hostVar.getType();
1662
1663 // Create init block with arguments: original value + bounds
1664 SmallVector<Type> argTypes{varType};
1665 SmallVector<Location> argLocs{loc};
1666 for (Value bound : bounds) {
1667 argTypes.push_back(bound.getType());
1668 argLocs.push_back(loc);
1669 }
1670
1671 Block *initBlock = builder.createBlock(&initRegion);
1672 initBlock->addArguments(argTypes, argLocs);
1673 builder.setInsertionPointToStart(initBlock);
1674
1675 Value privatizedValue;
1676
1677 // Get the block argument that represents the original variable
1678 Value blockArgVar = initBlock->getArgument(0);
1679
1680 // Generate init region body based on variable type
1681 if (isa<MappableType>(varType)) {
1682 auto mappableTy = cast<MappableType>(varType);
1683 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1684 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1685 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1686 privatizedValue = mappableTy.generatePrivateInit(
1687 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1688 if (!privatizedValue)
1689 return failure();
1690 } else {
1691 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1692 auto pointerLikeTy = cast<PointerLikeType>(varType);
1693 // Use PointerLikeType's allocation API with the block argument
1694 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1695 blockArgVar, needsFree);
1696 if (!privatizedValue)
1697 return failure();
1698 }
1699
1700 // Add yield operation to init block
1701 acc::YieldOp::create(builder, loc, privatizedValue);
1702
1703 return success();
1704}
1705
1706/// Create and populate a copy region for firstprivate recipes.
1707/// Returns success if the region is populated, failure otherwise.
1708/// `varInfo` must be the attribute produced by `createInitRegion` for
1709/// `MappableType` (it is unused for `PointerLikeType` copy paths).
1710static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1711 Region &copyRegion, Type varType,
1712 ValueRange bounds,
1713 acc::VariableInfoAttr varInfo) {
1714 // Create copy block with arguments: original value + privatized value +
1715 // bounds
1716 SmallVector<Type> copyArgTypes{varType, varType};
1717 SmallVector<Location> copyArgLocs{loc, loc};
1718 for (Value bound : bounds) {
1719 copyArgTypes.push_back(bound.getType());
1720 copyArgLocs.push_back(loc);
1721 }
1722
1723 Block *copyBlock = builder.createBlock(&copyRegion);
1724 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1725 builder.setInsertionPointToStart(copyBlock);
1726
1727 Value originalArg = copyBlock->getArgument(0);
1728 Value privatizedArg = copyBlock->getArgument(1);
1729
1730 if (isa<MappableType>(varType)) {
1731 auto mappableTy = cast<MappableType>(varType);
1732 // generateCopy(src, dest): copy from original (arg0) into privatized
1733 // (arg1).
1734 if (!mappableTy.generateCopy(
1735 builder, loc, cast<TypedValue<MappableType>>(originalArg),
1736 cast<TypedValue<MappableType>>(privatizedArg), bounds, varInfo))
1737 return failure();
1738 } else {
1739 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1740 auto pointerLikeTy = cast<PointerLikeType>(varType);
1741 if (!pointerLikeTy.genCopy(
1742 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1743 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1744 return failure();
1745 }
1746
1747 // Add terminator to copy block
1748 acc::TerminatorOp::create(builder, loc);
1749
1750 return success();
1751}
1752
1753/// Create and populate a destroy region for privatization recipes.
1754/// Returns success if the region is populated, failure otherwise.
1755/// The `varInfo` carries language-specific metadata produced by
1756/// `createInitRegion`.
1757static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1758 Region &destroyRegion, Type varType,
1759 Value allocRes, ValueRange bounds,
1760 acc::VariableInfoAttr varInfo) {
1761 // Create destroy block with arguments: original value + privatized value +
1762 // bounds
1763 SmallVector<Type> destroyArgTypes{varType, varType};
1764 SmallVector<Location> destroyArgLocs{loc, loc};
1765 for (Value bound : bounds) {
1766 destroyArgTypes.push_back(bound.getType());
1767 destroyArgLocs.push_back(loc);
1768 }
1769
1770 Block *destroyBlock = builder.createBlock(&destroyRegion);
1771 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1772 builder.setInsertionPointToStart(destroyBlock);
1773
1774 auto varToFree =
1775 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1776 if (isa<MappableType>(varType)) {
1777 auto mappableTy = cast<MappableType>(varType);
1778 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1779 varInfo))
1780 return failure();
1781 } else {
1782 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1783 auto pointerLikeTy = cast<PointerLikeType>(varType);
1784 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1785 return failure();
1786 }
1787
1788 acc::TerminatorOp::create(builder, loc);
1789 return success();
1790}
1791
1792} // namespace
1793
1794//===----------------------------------------------------------------------===//
1795// PrivateRecipeOp
1796//===----------------------------------------------------------------------===//
1797
1799 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1800 Type type, bool verifyYield, bool optional = false) {
1801 if (optional && region.empty())
1802 return success();
1803
1804 if (region.empty())
1805 return op->emitOpError() << "expects non-empty " << regionName << " region";
1806 Block &firstBlock = region.front();
1807 if (firstBlock.getNumArguments() < 1 ||
1808 firstBlock.getArgument(0).getType() != type)
1809 return op->emitOpError() << "expects " << regionName
1810 << " region first "
1811 "argument of the "
1812 << regionType << " type";
1813
1814 if (verifyYield) {
1815 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1816 if (yieldOp.getOperands().size() != 1 ||
1817 yieldOp.getOperands().getTypes()[0] != type)
1818 return op->emitOpError() << "expects " << regionName
1819 << " region to "
1820 "yield a value of the "
1821 << regionType << " type";
1822 }
1823 }
1824 return success();
1825}
1826
1827LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1828 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1829 "privatization", "init", getType(),
1830 /*verifyYield=*/false)))
1831 return failure();
1833 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1834 /*verifyYield=*/false, /*optional=*/true)))
1835 return failure();
1836 return success();
1837}
1838
1839std::optional<PrivateRecipeOp>
1840PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1841 StringRef recipeName, Value hostVar,
1842 StringRef varName, ValueRange bounds) {
1843 Type varType = hostVar.getType();
1844
1845 // First, validate that we can handle this variable type
1846 bool isMappable = isa<MappableType>(varType);
1847 bool isPointerLike = isa<PointerLikeType>(varType);
1848
1849 // Unsupported type
1850 if (!isMappable && !isPointerLike)
1851 return std::nullopt;
1852
1853 OpBuilder::InsertionGuard guard(builder);
1854
1855 // Create the recipe operation first so regions have proper parent context
1856 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1857
1858 // Populate the init region
1859 bool needsFree = false;
1860 acc::VariableInfoAttr varInfo;
1861 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1862 varName, bounds, needsFree, varInfo))) {
1863 recipe.erase();
1864 return std::nullopt;
1865 }
1866
1867 // Only create destroy region if the allocation needs deallocation
1868 if (needsFree) {
1869 // Extract the allocated value from the init block's yield operation
1870 auto yieldOp =
1871 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1872 Value allocRes = yieldOp.getOperand(0);
1873
1874 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1875 varType, allocRes, bounds, varInfo))) {
1876 recipe.erase();
1877 return std::nullopt;
1878 }
1879 }
1880
1881 return recipe;
1882}
1883
1884std::optional<PrivateRecipeOp>
1885PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1886 StringRef recipeName,
1887 FirstprivateRecipeOp firstprivRecipe) {
1888 // Create the private.recipe op with the same type as the firstprivate.recipe.
1889 OpBuilder::InsertionGuard guard(builder);
1890 auto varType = firstprivRecipe.getType();
1891 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1892
1893 // Clone the init region
1894 IRMapping mapping;
1895 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1896
1897 // Clone destroy region if the firstprivate.recipe has one.
1898 if (!firstprivRecipe.getDestroyRegion().empty()) {
1899 IRMapping mapping;
1900 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1901 mapping);
1902 }
1903 return recipe;
1904}
1905
1906//===----------------------------------------------------------------------===//
1907// FirstprivateRecipeOp
1908//===----------------------------------------------------------------------===//
1909
1910LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1911 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1912 "privatization", "init", getType(),
1913 /*verifyYield=*/false)))
1914 return failure();
1915
1916 if (getCopyRegion().empty())
1917 return emitOpError() << "expects non-empty copy region";
1918
1919 Block &firstBlock = getCopyRegion().front();
1920 if (firstBlock.getNumArguments() < 2 ||
1921 firstBlock.getArgument(0).getType() != getType())
1922 return emitOpError() << "expects copy region with two arguments of the "
1923 "privatization type";
1924
1925 if (getDestroyRegion().empty())
1926 return success();
1927
1928 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1929 "privatization", "destroy",
1930 getType(), /*verifyYield=*/false)))
1931 return failure();
1932
1933 return success();
1934}
1935
1936std::optional<FirstprivateRecipeOp>
1937FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1938 StringRef recipeName, Value hostVar,
1939 StringRef varName, ValueRange bounds) {
1940 Type varType = hostVar.getType();
1941
1942 // First, validate that we can handle this variable type
1943 bool isMappable = isa<MappableType>(varType);
1944 bool isPointerLike = isa<PointerLikeType>(varType);
1945
1946 // Unsupported type
1947 if (!isMappable && !isPointerLike)
1948 return std::nullopt;
1949
1950 OpBuilder::InsertionGuard guard(builder);
1951
1952 // Create the recipe operation first so regions have proper parent context
1953 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1954
1955 // Populate the init region
1956 bool needsFree = false;
1957 // Filled by createInitRegion for mappable variables (genPrivateVariableInfo);
1958 // then passed through to copy/destroy so generateCopy /
1959 // generatePrivateDestroy receive the same metadata as generatePrivateInit.
1960 acc::VariableInfoAttr varInfo;
1961 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1962 varName, bounds, needsFree, varInfo))) {
1963 recipe.erase();
1964 return std::nullopt;
1965 }
1966
1967 // Populate the copy region (uses varInfo for MappableType::generateCopy).
1968 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1969 bounds, varInfo))) {
1970 recipe.erase();
1971 return std::nullopt;
1972 }
1973
1974 // Only create destroy region if the allocation needs deallocation
1975 if (needsFree) {
1976 // Extract the allocated value from the init block's yield operation
1977 auto yieldOp =
1978 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1979 Value allocRes = yieldOp.getOperand(0);
1980
1981 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1982 varType, allocRes, bounds, varInfo))) {
1983 recipe.erase();
1984 return std::nullopt;
1985 }
1986 }
1987
1988 return recipe;
1989}
1990
1991//===----------------------------------------------------------------------===//
1992// ReductionRecipeOp
1993//===----------------------------------------------------------------------===//
1994
1995LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1996 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1997 "init", getType(),
1998 /*verifyYield=*/false)))
1999 return failure();
2000
2001 if (getCombinerRegion().empty())
2002 return emitOpError() << "expects non-empty combiner region";
2003
2004 Block &reductionBlock = getCombinerRegion().front();
2005 if (reductionBlock.getNumArguments() < 2 ||
2006 reductionBlock.getArgument(0).getType() != getType() ||
2007 reductionBlock.getArgument(1).getType() != getType())
2008 return emitOpError() << "expects combiner region with the first two "
2009 << "arguments of the reduction type";
2010
2011 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
2012 if (yieldOp.getOperands().size() != 1 ||
2013 yieldOp.getOperands().getTypes()[0] != getType())
2014 return emitOpError() << "expects combiner region to yield a value "
2015 "of the reduction type";
2016 }
2017
2018 return success();
2019}
2020
2021//===----------------------------------------------------------------------===//
2022// ParallelOp
2023//===----------------------------------------------------------------------===//
2024
2025/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
2026template <typename Op>
2027static LogicalResult checkDataOperands(Op op,
2028 const mlir::ValueRange &operands) {
2029 for (mlir::Value operand : operands)
2030 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
2031 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
2032 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
2033 operand.getDefiningOp()))
2034 return op.emitError(
2035 "expect data entry/exit operation or acc.getdeviceptr "
2036 "as defining op");
2037 return success();
2038}
2039
2040template <typename OpT, typename RecipeOpT>
2041static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
2042 const mlir::ValueRange &operands,
2043 llvm::StringRef operandName) {
2045 for (mlir::Value operand : operands) {
2046 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2047 return accConstructOp->emitOpError()
2048 << "expected " << operandName << " as defining op";
2049 if (!set.insert(operand).second)
2050 return accConstructOp->emitOpError()
2051 << operandName << " operand appears more than once";
2052 }
2053 return success();
2054}
2055
2056unsigned ParallelOp::getNumDataOperands() {
2057 return getReductionOperands().size() + getPrivateOperands().size() +
2058 getFirstprivateOperands().size() + getDataClauseOperands().size();
2059}
2060
2061Value ParallelOp::getDataOperand(unsigned i) {
2062 unsigned numOptional = getAsyncOperands().size();
2063 numOptional += getNumGangs().size();
2064 numOptional += getNumWorkers().size();
2065 numOptional += getVectorLength().size();
2066 numOptional += getIfCond() ? 1 : 0;
2067 numOptional += getSelfCond() ? 1 : 0;
2068 return getOperand(getWaitOperands().size() + numOptional + i);
2069}
2070
2071template <typename Op>
2072static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
2073 ArrayAttr deviceTypes,
2074 llvm::StringRef keyword) {
2075 if (!operands.empty() &&
2076 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
2077 return op.emitOpError() << keyword << " operands count must match "
2078 << keyword << " device_type count";
2079 return success();
2080}
2081
2082template <typename Op>
2084 Op op, OperandRange operands, DenseI32ArrayAttr segments,
2085 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2086 std::size_t numOperandsInSegments = 0;
2087 std::size_t nbOfSegments = 0;
2088
2089 if (segments) {
2090 for (auto segCount : segments.asArrayRef()) {
2091 if (maxInSegment != 0 && segCount > maxInSegment)
2092 return op.emitOpError() << keyword << " expects a maximum of "
2093 << maxInSegment << " values per segment";
2094 numOperandsInSegments += segCount;
2095 ++nbOfSegments;
2096 }
2097 }
2098
2099 if ((numOperandsInSegments != operands.size()) ||
2100 (!deviceTypes && !operands.empty()))
2101 return op.emitOpError()
2102 << keyword << " operand count does not match count in segments";
2103 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2104 return op.emitOpError()
2105 << keyword << " segment count does not match device_type count";
2106 return success();
2107}
2108
/// Verifies acc.parallel: the private/firstprivate/reduction operands, the
/// device_type bookkeeping of num_gangs, wait, num_workers, vector_length and
/// async, and finally the defining ops of the data clause operands.
2109LogicalResult acc::ParallelOp::verify() {
// Each private/firstprivate/reduction operand must come from the matching
// acc data op and appear at most once.
2110 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2111 mlir::acc::PrivateRecipeOp>(
2112 *this, getPrivateOperands(), "private")))
2113 return failure();
2114 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2115 mlir::acc::FirstprivateRecipeOp>(
2116 *this, getFirstprivateOperands(), "firstprivate")))
2117 return failure();
2118 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2119 mlir::acc::ReductionRecipeOp>(
2120 *this, getReductionOperands(), "reduction")))
2121 return failure();
2122
// num_gangs is segmented per device_type, with at most 3 values per segment.
// NOTE(review): the call line preceding this argument list is not visible here.
2124 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2125 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2126 return failure();
2127
// wait operands are segmented per device_type (no per-segment limit).
2129 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2130 getWaitOperandsDeviceTypeAttr(), "wait")))
2131 return failure();
2132
// num_workers / vector_length / async: one device_type entry per operand.
2133 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2134 getNumWorkersDeviceTypeAttr(),
2135 "num_workers")))
2136 return failure();
2137
2138 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2139 getVectorLengthDeviceTypeAttr(),
2140 "vector_length")))
2141 return failure();
2142
2144 getAsyncOperandsDeviceTypeAttr(),
2145 "async")))
2146 return failure();
2147
// NOTE(review): the condition guarding this early return is not visible in
// this view; confirm which check it belongs to.
2149 return failure();
2150
2151 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
2152}
2153
2154static mlir::Value
2155getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
2157 mlir::acc::DeviceType deviceType) {
2158 if (!arrayAttr)
2159 return {};
2160 if (auto pos = findSegment(*arrayAttr, deviceType))
2161 return range[*pos];
2162 return {};
2163}
2164
2165bool acc::ParallelOp::hasAsyncOnly() {
2166 return hasAsyncOnly(mlir::acc::DeviceType::None);
2167}
2168
2169bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2170 return hasDeviceType(getAsyncOnly(), deviceType);
2171}
2172
2173mlir::Value acc::ParallelOp::getAsyncValue() {
2174 return getAsyncValue(mlir::acc::DeviceType::None);
2175}
2176
2177mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2179 getAsyncOperands(), deviceType);
2180}
2181
2182mlir::Value acc::ParallelOp::getNumWorkersValue() {
2183 return getNumWorkersValue(mlir::acc::DeviceType::None);
2184}
2185
2187acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2188 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2189 deviceType);
2190}
2191
2192mlir::Value acc::ParallelOp::getVectorLengthValue() {
2193 return getVectorLengthValue(mlir::acc::DeviceType::None);
2194}
2195
2197acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2198 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2199 getVectorLength(), deviceType);
2200}
2201
2202mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2203 return getNumGangsValues(mlir::acc::DeviceType::None);
2204}
2205
2207ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2208 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2209 getNumGangsSegments(), deviceType);
2210}
2211
2213 std::optional<mlir::ArrayAttr> numGangsDeviceType,
2215 std::optional<llvm::ArrayRef<int32_t>> numGangsSegments,
2216 std::optional<mlir::ArrayAttr> numWorkersDeviceType,
2218 std::optional<mlir::ArrayAttr> vectorLengthDeviceType,
2219 mlir::Operation::operand_range vectorLength,
2220 mlir::acc::DeviceType deviceType) {
// True when `deviceType` has a non-empty num_gangs segment, or a
// num_workers or vector_length value.
2221 return !getValuesFromSegments(numGangsDeviceType, numGangs, numGangsSegments,
2222 deviceType)
2223 .empty() ||
2224 getValueInDeviceTypeSegment(numWorkersDeviceType, numWorkers,
2225 deviceType) ||
2226 getValueInDeviceTypeSegment(vectorLengthDeviceType, vectorLength,
2227 deviceType);
2228}
2229
/// Returns true if any of num_gangs, num_workers or vector_length carries a
/// value for `deviceType`.
2230bool acc::ParallelOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
2232 getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
2233 getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
2234 getVectorLength(), deviceType);
2235}
2236
2237bool acc::ParallelOp::hasWaitOnly() {
2238 return hasWaitOnly(mlir::acc::DeviceType::None);
2239}
2240
2241bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2242 return hasDeviceType(getWaitOnly(), deviceType);
2243}
2244
2245mlir::Operation::operand_range ParallelOp::getWaitValues() {
2246 return getWaitValues(mlir::acc::DeviceType::None);
2247}
2248
2250ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
// Returns the wait operands for `deviceType`. The helper also receives the
// hasWaitDevnum flags, presumably so that a leading devnum operand can be
// excluded from the result; confirm against the helper's definition.
2252 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2253 getHasWaitDevnum(), deviceType);
2254}
2255
2256mlir::Value ParallelOp::getWaitDevnum() {
2257 return getWaitDevnum(mlir::acc::DeviceType::None);
2258}
2259
2260mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2261 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2262 getWaitOperandsSegments(), getHasWaitDevnum(),
2263 deviceType);
2264}
2265
/// Convenience builder for acc.parallel that takes plain operand lists and
/// passes null for every device_type list, segment list, hasWaitDevnum,
/// keyword-only marker, self, default and combined attribute.
/// NOTE(review): with null device_type attributes, ParallelOp::verify's count
/// checks reject non-empty numGangs/numWorkers/vectorLength/asyncOperands/
/// waitOperands. Confirm callers only pass empty ranges for those.
2266void ParallelOp::build(mlir::OpBuilder &odsBuilder,
2267 mlir::OperationState &odsState,
2268 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
2269 mlir::ValueRange vectorLength,
2270 mlir::ValueRange asyncOperands,
2271 mlir::ValueRange waitOperands, mlir::Value ifCond,
2272 mlir::Value selfCond, mlir::ValueRange reductionOperands,
2273 mlir::ValueRange gangPrivateOperands,
2274 mlir::ValueRange gangFirstPrivateOperands,
2275 mlir::ValueRange dataClauseOperands) {
// Forward to the generated builder; argument order follows its signature.
2276 ParallelOp::build(
2277 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
2278 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
2279 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
2280 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
2281 /*numGangsDeviceType=*/nullptr, numWorkers,
2282 /*numWorkersDeviceType=*/nullptr, vectorLength,
2283 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
2284 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
2285 gangFirstPrivateOperands, dataClauseOperands,
2286 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
2287}
2288
2289void acc::ParallelOp::addNumWorkersOperand(
2290 MLIRContext *context, mlir::Value newValue,
2291 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2292 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2293 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2294 getNumWorkersMutable()));
2295}
2296void acc::ParallelOp::addVectorLengthOperand(
2297 MLIRContext *context, mlir::Value newValue,
2298 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2299 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2300 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2301 getVectorLengthMutable()));
2302}
2303
2304void acc::ParallelOp::addAsyncOnly(
2305 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2306 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2307 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2308}
2309
2310void acc::ParallelOp::addAsyncOperand(
2311 MLIRContext *context, mlir::Value newValue,
2312 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2313 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2314 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2315 getAsyncOperandsMutable()));
2316}
2317
2318void acc::ParallelOp::addNumGangsOperands(
2319 MLIRContext *context, mlir::ValueRange newValues,
2320 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2322 if (getNumGangsSegments())
2323 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2324
2325 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2326 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2327 getNumGangsMutable(), segments));
2328
2329 setNumGangsSegments(segments);
2330}
2331void acc::ParallelOp::addWaitOnly(
2332 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2333 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2334 effectiveDeviceTypes));
2335}
2336void acc::ParallelOp::addWaitOperands(
2337 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2338 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2339
2341 if (getWaitOperandsSegments())
2342 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2343
2344 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2345 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2346 getWaitOperandsMutable(), segments));
2347 setWaitOperandsSegments(segments);
2348
2350 if (getHasWaitDevnumAttr())
2351 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2352 hasDevnums.insert(
2353 hasDevnums.end(),
2354 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2355 mlir::BoolAttr::get(context, hasDevnum));
2356 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2357}
2358
2359void acc::ParallelOp::addPrivatization(MLIRContext *context,
2360 mlir::acc::PrivateOp op,
2361 mlir::acc::PrivateRecipeOp recipe) {
2362 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2363 getPrivateOperandsMutable().append(op.getResult());
2364}
2365
2366void acc::ParallelOp::addFirstPrivatization(
2367 MLIRContext *context, mlir::acc::FirstprivateOp op,
2368 mlir::acc::FirstprivateRecipeOp recipe) {
2369 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2370 getFirstprivateOperandsMutable().append(op.getResult());
2371}
2372
2373void acc::ParallelOp::addReduction(MLIRContext *context,
2374 mlir::acc::ReductionOp op,
2375 mlir::acc::ReductionRecipeOp recipe) {
2376 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2377 getReductionOperandsMutable().append(op.getResult());
2378}
2379
2380static ParseResult parseNumGangs(
2381 mlir::OpAsmParser &parser,
2383 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2384 mlir::DenseI32ArrayAttr &segments) {
2387
2388 do {
2389 if (failed(parser.parseLBrace()))
2390 return failure();
2391
2392 int32_t crtOperandsSize = operands.size();
2393 if (failed(parser.parseCommaSeparatedList(
2395 if (parser.parseOperand(operands.emplace_back()) ||
2396 parser.parseColonType(types.emplace_back()))
2397 return failure();
2398 return success();
2399 })))
2400 return failure();
2401 seg.push_back(operands.size() - crtOperandsSize);
2402
2403 if (failed(parser.parseRBrace()))
2404 return failure();
2405
2406 if (succeeded(parser.parseOptionalLSquare())) {
2407 if (parser.parseAttribute(attributes.emplace_back()) ||
2408 parser.parseRSquare())
2409 return failure();
2410 } else {
2411 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2412 parser.getContext(), mlir::acc::DeviceType::None));
2413 }
2414 } while (succeeded(parser.parseOptionalComma()));
2415
2416 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2417 attributes.end());
2418 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2419 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2420
2421 return success();
2422}
2423
2425 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2426 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2427 p << " [" << attr << "]";
2428}
2429
2431 mlir::OperandRange operands, mlir::TypeRange types,
2432 std::optional<mlir::ArrayAttr> deviceTypes,
2433 std::optional<mlir::DenseI32ArrayAttr> segments) {
2434 unsigned opIdx = 0;
2435 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2436 p << "{";
2437 llvm::interleaveComma(
2438 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2439 p << operands[opIdx] << " : " << operands[opIdx].getType();
2440 ++opIdx;
2441 });
2442 p << "}";
2443 printSingleDeviceType(p, it.value());
2444 });
2445}
2446
2448 mlir::OpAsmParser &parser,
2450 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2451 mlir::DenseI32ArrayAttr &segments) {
2454
2455 do {
2456 if (failed(parser.parseLBrace()))
2457 return failure();
2458
2459 int32_t crtOperandsSize = operands.size();
2460
2461 if (failed(parser.parseCommaSeparatedList(
2463 if (parser.parseOperand(operands.emplace_back()) ||
2464 parser.parseColonType(types.emplace_back()))
2465 return failure();
2466 return success();
2467 })))
2468 return failure();
2469
2470 seg.push_back(operands.size() - crtOperandsSize);
2471
2472 if (failed(parser.parseRBrace()))
2473 return failure();
2474
2475 if (succeeded(parser.parseOptionalLSquare())) {
2476 if (parser.parseAttribute(attributes.emplace_back()) ||
2477 parser.parseRSquare())
2478 return failure();
2479 } else {
2480 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2481 parser.getContext(), mlir::acc::DeviceType::None));
2482 }
2483 } while (succeeded(parser.parseOptionalComma()));
2484
2485 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2486 attributes.end());
2487 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2488 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2489
2490 return success();
2491}
2492
2495 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2496 std::optional<mlir::DenseI32ArrayAttr> segments) {
2497 unsigned opIdx = 0;
2498 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2499 p << "{";
2500 llvm::interleaveComma(
2501 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2502 p << operands[opIdx] << " : " << operands[opIdx].getType();
2503 ++opIdx;
2504 });
2505 p << "}";
2506 printSingleDeviceType(p, it.value());
2507 });
2508}
2509
2510static ParseResult parseWaitClause(
2511 mlir::OpAsmParser &parser,
2513 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2514 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2515 mlir::ArrayAttr &keywordOnly) {
2516 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2518
2519 bool needCommaBeforeOperands = false;
2520
2521 // Keyword only
2522 if (failed(parser.parseOptionalLParen())) {
2523 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2524 parser.getContext(), mlir::acc::DeviceType::None));
2525 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2526 return success();
2527 }
2528
2529 // Parse keyword only attributes
2530 if (succeeded(parser.parseOptionalLSquare())) {
2531 if (failed(parser.parseCommaSeparatedList([&]() {
2532 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2533 return failure();
2534 return success();
2535 })))
2536 return failure();
2537 if (parser.parseRSquare())
2538 return failure();
2539 needCommaBeforeOperands = true;
2540 }
2541
2542 if (needCommaBeforeOperands && failed(parser.parseComma()))
2543 return failure();
2544
2545 do {
2546 if (failed(parser.parseLBrace()))
2547 return failure();
2548
2549 int32_t crtOperandsSize = operands.size();
2550
2551 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2552 if (failed(parser.parseColon()))
2553 return failure();
2554 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2555 } else {
2556 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2557 }
2558
2559 if (failed(parser.parseCommaSeparatedList(
2561 if (parser.parseOperand(operands.emplace_back()) ||
2562 parser.parseColonType(types.emplace_back()))
2563 return failure();
2564 return success();
2565 })))
2566 return failure();
2567
2568 seg.push_back(operands.size() - crtOperandsSize);
2569
2570 if (failed(parser.parseRBrace()))
2571 return failure();
2572
2573 if (succeeded(parser.parseOptionalLSquare())) {
2574 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2575 parser.parseRSquare())
2576 return failure();
2577 } else {
2578 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2579 parser.getContext(), mlir::acc::DeviceType::None));
2580 }
2581 } while (succeeded(parser.parseOptionalComma()));
2582
2583 if (failed(parser.parseRParen()))
2584 return failure();
2585
2586 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2587 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2588 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2589 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2590
2591 return success();
2592}
2593
2594static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2595 if (!hasDeviceTypeValues(attrs))
2596 return false;
2597 if (attrs->size() != 1)
2598 return false;
2599 if (auto deviceTypeAttr =
2600 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2601 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2602 return false;
2603}
2604
2606 mlir::OperandRange operands, mlir::TypeRange types,
2607 std::optional<mlir::ArrayAttr> deviceTypes,
2608 std::optional<mlir::DenseI32ArrayAttr> segments,
2609 std::optional<mlir::ArrayAttr> hasDevNum,
2610 std::optional<mlir::ArrayAttr> keywordOnly) {
2611
2612 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2613 return;
2614
2615 p << "(";
2616
2617 printDeviceTypes(p, keywordOnly);
2618 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2619 p << ", ";
2620
2621 if (hasDeviceTypeValues(deviceTypes)) {
2622 unsigned opIdx = 0;
2623 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2624 p << "{";
2625 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2626 if (boolAttr && boolAttr.getValue())
2627 p << "devnum: ";
2628 llvm::interleaveComma(
2629 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2630 p << operands[opIdx] << " : " << operands[opIdx].getType();
2631 ++opIdx;
2632 });
2633 p << "}";
2634 printSingleDeviceType(p, it.value());
2635 });
2636 }
2637
2638 p << ")";
2639}
2640
2641static ParseResult parseDeviceTypeOperands(
2642 mlir::OpAsmParser &parser,
2644 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2646 if (failed(parser.parseCommaSeparatedList([&]() {
2647 if (parser.parseOperand(operands.emplace_back()) ||
2648 parser.parseColonType(types.emplace_back()))
2649 return failure();
2650 if (succeeded(parser.parseOptionalLSquare())) {
2651 if (parser.parseAttribute(attributes.emplace_back()) ||
2652 parser.parseRSquare())
2653 return failure();
2654 } else {
2655 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2656 parser.getContext(), mlir::acc::DeviceType::None));
2657 }
2658 return success();
2659 })))
2660 return failure();
2661 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2662 attributes.end());
2663 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2664 return success();
2665}
2666
2667static void
2669 mlir::OperandRange operands, mlir::TypeRange types,
2670 std::optional<mlir::ArrayAttr> deviceTypes) {
2671 if (!hasDeviceTypeValues(deviceTypes))
2672 return;
2673 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2674 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2675 printSingleDeviceType(p, std::get<0>(it));
2676 });
2677}
2678
2680 mlir::OpAsmParser &parser,
2682 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2683 mlir::ArrayAttr &keywordOnlyDeviceType) {
2684
2685 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2686 bool needCommaBeforeOperands = false;
2687
2688 if (failed(parser.parseOptionalLParen())) {
2689 // Keyword only
2690 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2691 parser.getContext(), mlir::acc::DeviceType::None));
2692 keywordOnlyDeviceType =
2693 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2694 return success();
2695 }
2696
2697 // Parse keyword only attributes
2698 if (succeeded(parser.parseOptionalLSquare())) {
2699 // Parse keyword only attributes
2700 if (failed(parser.parseCommaSeparatedList([&]() {
2701 if (parser.parseAttribute(
2702 keywordOnlyDeviceTypeAttributes.emplace_back()))
2703 return failure();
2704 return success();
2705 })))
2706 return failure();
2707 if (parser.parseRSquare())
2708 return failure();
2709 needCommaBeforeOperands = true;
2710 }
2711
2712 if (needCommaBeforeOperands && failed(parser.parseComma()))
2713 return failure();
2714
2716 if (failed(parser.parseCommaSeparatedList([&]() {
2717 if (parser.parseOperand(operands.emplace_back()) ||
2718 parser.parseColonType(types.emplace_back()))
2719 return failure();
2720 if (succeeded(parser.parseOptionalLSquare())) {
2721 if (parser.parseAttribute(attributes.emplace_back()) ||
2722 parser.parseRSquare())
2723 return failure();
2724 } else {
2725 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2726 parser.getContext(), mlir::acc::DeviceType::None));
2727 }
2728 return success();
2729 })))
2730 return failure();
2731
2732 if (failed(parser.parseRParen()))
2733 return failure();
2734
2735 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2736 attributes.end());
2737 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2738 return success();
2739}
2740
2743 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2744 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2745
2746 if (operands.begin() == operands.end() &&
2747 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2748 return;
2749 }
2750
2751 p << "(";
2752 printDeviceTypes(p, keywordOnlyDeviceTypes);
2753 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2754 hasDeviceTypeValues(deviceTypes))
2755 p << ", ";
2756 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2757 p << ")";
2758}
2759
2761 mlir::OpAsmParser &parser,
2762 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2763 mlir::Type &operandType, mlir::UnitAttr &attr) {
2764 // Keyword only
2765 if (failed(parser.parseOptionalLParen())) {
2766 attr = mlir::UnitAttr::get(parser.getContext());
2767 return success();
2768 }
2769
2771 if (failed(parser.parseOperand(op)))
2772 return failure();
2773 operand = op;
2774 if (failed(parser.parseColon()))
2775 return failure();
2776 if (failed(parser.parseType(operandType)))
2777 return failure();
2778 if (failed(parser.parseRParen()))
2779 return failure();
2780
2781 return success();
2782}
2783
2785 mlir::Operation *op,
2786 std::optional<mlir::Value> operand,
2787 mlir::Type operandType,
2788 mlir::UnitAttr attr) {
2789 if (attr)
2790 return;
2791
2792 p << "(";
2793 p.printOperand(*operand);
2794 p << " : ";
2795 p.printType(operandType);
2796 p << ")";
2797}
2798
2800 mlir::OpAsmParser &parser,
2802 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2803 // Keyword only
2804 if (failed(parser.parseOptionalLParen())) {
2805 attr = mlir::UnitAttr::get(parser.getContext());
2806 return success();
2807 }
2808
2809 if (failed(parser.parseCommaSeparatedList([&]() {
2810 if (parser.parseOperand(operands.emplace_back()))
2811 return failure();
2812 return success();
2813 })))
2814 return failure();
2815 if (failed(parser.parseColon()))
2816 return failure();
2817 if (failed(parser.parseCommaSeparatedList([&]() {
2818 if (parser.parseType(types.emplace_back()))
2819 return failure();
2820 return success();
2821 })))
2822 return failure();
2823 if (failed(parser.parseRParen()))
2824 return failure();
2825
2826 return success();
2827}
2828
2830 mlir::Operation *op,
2831 mlir::OperandRange operands,
2832 mlir::TypeRange types,
2833 mlir::UnitAttr attr) {
2834 if (attr)
2835 return;
2836
2837 p << "(";
2838 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2839 p << " : ";
2840 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2841 p << ")";
2842}
2843
2844static ParseResult
2846 mlir::acc::CombinedConstructsTypeAttr &attr) {
2847 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2848 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2849 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2850 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2851 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2852 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2853 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2854 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2855 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2856 } else {
2857 parser.emitError(parser.getCurrentLocation(),
2858 "expected compute construct name");
2859 return failure();
2860 }
2861 return success();
2862}
2863
2864static void
2866 mlir::acc::CombinedConstructsTypeAttr attr) {
2867 if (attr) {
2868 switch (attr.getValue()) {
2869 case mlir::acc::CombinedConstructsType::KernelsLoop:
2870 p << "kernels";
2871 break;
2872 case mlir::acc::CombinedConstructsType::ParallelLoop:
2873 p << "parallel";
2874 break;
2875 case mlir::acc::CombinedConstructsType::SerialLoop:
2876 p << "serial";
2877 break;
2878 };
2879 }
2880}
2881
2882//===----------------------------------------------------------------------===//
2883// SerialOp
2884//===----------------------------------------------------------------------===//
2885
2886unsigned SerialOp::getNumDataOperands() {
2887 return getReductionOperands().size() + getPrivateOperands().size() +
2888 getFirstprivateOperands().size() + getDataClauseOperands().size();
2889}
2890
2891Value SerialOp::getDataOperand(unsigned i) {
2892 unsigned numOptional = getAsyncOperands().size();
2893 numOptional += getIfCond() ? 1 : 0;
2894 numOptional += getSelfCond() ? 1 : 0;
2895 return getOperand(getWaitOperands().size() + numOptional + i);
2896}
2897
2898bool acc::SerialOp::hasAsyncOnly() {
2899 return hasAsyncOnly(mlir::acc::DeviceType::None);
2900}
2901
2902bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2903 return hasDeviceType(getAsyncOnly(), deviceType);
2904}
2905
2906mlir::Value acc::SerialOp::getAsyncValue() {
2907 return getAsyncValue(mlir::acc::DeviceType::None);
2908}
2909
2910mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2912 getAsyncOperands(), deviceType);
2913}
2914
2915bool acc::SerialOp::hasWaitOnly() {
2916 return hasWaitOnly(mlir::acc::DeviceType::None);
2917}
2918
2919bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2920 return hasDeviceType(getWaitOnly(), deviceType);
2921}
2922
2923mlir::Operation::operand_range SerialOp::getWaitValues() {
2924 return getWaitValues(mlir::acc::DeviceType::None);
2925}
2926
2928SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2930 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2931 getHasWaitDevnum(), deviceType);
2932}
2933
2934mlir::Value SerialOp::getWaitDevnum() {
2935 return getWaitDevnum(mlir::acc::DeviceType::None);
2936}
2937
2938mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2939 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2940 getWaitOperandsSegments(), getHasWaitDevnum(),
2941 deviceType);
2942}
2943
/// Verifier for acc.serial. The checks run in order and the first failure
/// determines the diagnostic:
///   1. private / firstprivate / reduction operands against their recipe ops;
///   2. wait operands against their segments and device types;
///   3. async operands against their device types;
///   4. one further check (see NOTE below);
///   5. the data-clause operands (checkDataOperands).
2944LogicalResult acc::SerialOp::verify() {
2945 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2946 mlir::acc::PrivateRecipeOp>(
2947 *this, getPrivateOperands(), "private")))
2948 return failure();
2949 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2950 mlir::acc::FirstprivateRecipeOp>(
2951 *this, getFirstprivateOperands(), "firstprivate")))
2952 return failure();
2953 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2954 mlir::acc::ReductionRecipeOp>(
2955 *this, getReductionOperands(), "reduction")))
2956 return failure();
2957
// Wait operands vs. their segment sizes and device-type attribute.
2959 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2960 getWaitOperandsDeviceTypeAttr(), "wait")))
2961 return failure();
2962
// Async operands vs. their device-type attribute.
2964 getAsyncOperandsDeviceTypeAttr(),
2965 "async")))
2966 return failure();
2967
// NOTE(review): the call guarded by this `return failure()` is not shown in
// this listing; presumably a wait/async consistency check — TODO confirm.
2969 return failure();
2970
2971 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2972}
2973
2974void acc::SerialOp::addAsyncOnly(
2975 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2976 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2977 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2978}
2979
2980void acc::SerialOp::addAsyncOperand(
2981 MLIRContext *context, mlir::Value newValue,
2982 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2983 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2984 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2985 getAsyncOperandsMutable()));
2986}
2987
2988void acc::SerialOp::addWaitOnly(
2989 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2990 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2991 effectiveDeviceTypes));
2992}
2993void acc::SerialOp::addWaitOperands(
2994 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2995 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2996
2998 if (getWaitOperandsSegments())
2999 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3000
3001 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3002 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3003 getWaitOperandsMutable(), segments));
3004 setWaitOperandsSegments(segments);
3005
3007 if (getHasWaitDevnumAttr())
3008 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3009 hasDevnums.insert(
3010 hasDevnums.end(),
3011 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3012 mlir::BoolAttr::get(context, hasDevnum));
3013 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3014}
3015
3016void acc::SerialOp::addPrivatization(MLIRContext *context,
3017 mlir::acc::PrivateOp op,
3018 mlir::acc::PrivateRecipeOp recipe) {
3019 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3020 getPrivateOperandsMutable().append(op.getResult());
3021}
3022
3023void acc::SerialOp::addFirstPrivatization(
3024 MLIRContext *context, mlir::acc::FirstprivateOp op,
3025 mlir::acc::FirstprivateRecipeOp recipe) {
3026 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3027 getFirstprivateOperandsMutable().append(op.getResult());
3028}
3029
3030void acc::SerialOp::addReduction(MLIRContext *context,
3031 mlir::acc::ReductionOp op,
3032 mlir::acc::ReductionRecipeOp recipe) {
3033 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3034 getReductionOperandsMutable().append(op.getResult());
3035}
3036
3037//===----------------------------------------------------------------------===//
3038// KernelsOp
3039//===----------------------------------------------------------------------===//
3040
3041unsigned KernelsOp::getNumDataOperands() {
3042 return getDataClauseOperands().size();
3043}
3044
3045Value KernelsOp::getDataOperand(unsigned i) {
3046 unsigned numOptional = getAsyncOperands().size();
3047 numOptional += getWaitOperands().size();
3048 numOptional += getNumGangs().size();
3049 numOptional += getNumWorkers().size();
3050 numOptional += getVectorLength().size();
3051 numOptional += getIfCond() ? 1 : 0;
3052 numOptional += getSelfCond() ? 1 : 0;
3053 return getOperand(numOptional + i);
3054}
3055
3056bool acc::KernelsOp::hasAsyncOnly() {
3057 return hasAsyncOnly(mlir::acc::DeviceType::None);
3058}
3059
3060bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3061 return hasDeviceType(getAsyncOnly(), deviceType);
3062}
3063
3064mlir::Value acc::KernelsOp::getAsyncValue() {
3065 return getAsyncValue(mlir::acc::DeviceType::None);
3066}
3067
3068mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3070 getAsyncOperands(), deviceType);
3071}
3072
3073mlir::Value acc::KernelsOp::getNumWorkersValue() {
3074 return getNumWorkersValue(mlir::acc::DeviceType::None);
3075}
3076
3078acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3079 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
3080 deviceType);
3081}
3082
3083mlir::Value acc::KernelsOp::getVectorLengthValue() {
3084 return getVectorLengthValue(mlir::acc::DeviceType::None);
3085}
3086
3088acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3089 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
3090 getVectorLength(), deviceType);
3091}
3092
3093mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
3094 return getNumGangsValues(mlir::acc::DeviceType::None);
3095}
3096
3098KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3099 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
3100 getNumGangsSegments(), deviceType);
3101}
3102
/// Presumably reports whether any num_gangs / num_workers / vector_length
/// value is present for `deviceType`.
// NOTE(review): the forwarded-to helper's name line is not shown in this
// listing. The call passes the num_gangs values, segments and device types,
// the num_workers values and device types, the vector_length values and
// device types, and `deviceType`. Confirm the exact predicate against that
// helper.
3103bool acc::KernelsOp::hasAnyGangWorkerVector(mlir::acc::DeviceType deviceType) {
3105 getNumGangsDeviceType(), getNumGangs(), getNumGangsSegments(),
3106 getNumWorkersDeviceType(), getNumWorkers(), getVectorLengthDeviceType(),
3107 getVectorLength(), deviceType);
3108}
3109
3110bool acc::KernelsOp::hasWaitOnly() {
3111 return hasWaitOnly(mlir::acc::DeviceType::None);
3112}
3113
3114bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3115 return hasDeviceType(getWaitOnly(), deviceType);
3116}
3117
3118mlir::Operation::operand_range KernelsOp::getWaitValues() {
3119 return getWaitValues(mlir::acc::DeviceType::None);
3120}
3121
3123KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3125 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3126 getHasWaitDevnum(), deviceType);
3127}
3128
3129mlir::Value KernelsOp::getWaitDevnum() {
3130 return getWaitDevnum(mlir::acc::DeviceType::None);
3131}
3132
3133mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3134 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3135 getWaitOperandsSegments(), getHasWaitDevnum(),
3136 deviceType);
3137}
3138
/// Verifier for acc.kernels. The checks run in order and the first failure
/// determines the diagnostic:
///   1. num_gangs against its segments and device types (max 3 per segment);
///   2. wait operands against their segments and device types;
///   3. num_workers / vector_length / async against their device types;
///   4. one further check (see NOTE below);
///   5. the data-clause operands (checkDataOperands).
// NOTE(review): unlike SerialOp::verify, no checkPrivateOperands calls are
// made for private/firstprivate/reduction operands. Confirm this is intended.
3139LogicalResult acc::KernelsOp::verify() {
3141 *this, getNumGangs(), getNumGangsSegmentsAttr(),
3142 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
3143 return failure();
3144
3146 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3147 getWaitOperandsDeviceTypeAttr(), "wait")))
3148 return failure();
3149
3150 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
3151 getNumWorkersDeviceTypeAttr(),
3152 "num_workers")))
3153 return failure();
3154
3155 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
3156 getVectorLengthDeviceTypeAttr(),
3157 "vector_length")))
3158 return failure();
3159
3161 getAsyncOperandsDeviceTypeAttr(),
3162 "async")))
3163 return failure();
3164
// NOTE(review): the call guarded by this `return failure()` is not shown in
// this listing; presumably a wait/async consistency check — TODO confirm.
3166 return failure();
3167
3168 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
3169}
3170
3171void acc::KernelsOp::addPrivatization(MLIRContext *context,
3172 mlir::acc::PrivateOp op,
3173 mlir::acc::PrivateRecipeOp recipe) {
3174 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3175 getPrivateOperandsMutable().append(op.getResult());
3176}
3177
3178void acc::KernelsOp::addFirstPrivatization(
3179 MLIRContext *context, mlir::acc::FirstprivateOp op,
3180 mlir::acc::FirstprivateRecipeOp recipe) {
3181 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3182 getFirstprivateOperandsMutable().append(op.getResult());
3183}
3184
3185void acc::KernelsOp::addReduction(MLIRContext *context,
3186 mlir::acc::ReductionOp op,
3187 mlir::acc::ReductionRecipeOp recipe) {
3188 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3189 getReductionOperandsMutable().append(op.getResult());
3190}
3191
3192void acc::KernelsOp::addNumWorkersOperand(
3193 MLIRContext *context, mlir::Value newValue,
3194 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3195 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3196 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3197 getNumWorkersMutable()));
3198}
3199
3200void acc::KernelsOp::addVectorLengthOperand(
3201 MLIRContext *context, mlir::Value newValue,
3202 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3203 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3204 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3205 getVectorLengthMutable()));
3206}
3207void acc::KernelsOp::addAsyncOnly(
3208 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3209 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3210 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3211}
3212
3213void acc::KernelsOp::addAsyncOperand(
3214 MLIRContext *context, mlir::Value newValue,
3215 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3216 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3217 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3218 getAsyncOperandsMutable()));
3219}
3220
3221void acc::KernelsOp::addNumGangsOperands(
3222 MLIRContext *context, mlir::ValueRange newValues,
3223 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3225 if (getNumGangsSegmentsAttr())
3226 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3227
3228 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3229 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3230 getNumGangsMutable(), segments));
3231
3232 setNumGangsSegments(segments);
3233}
3234
3235void acc::KernelsOp::addWaitOnly(
3236 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3237 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3238 effectiveDeviceTypes));
3239}
3240void acc::KernelsOp::addWaitOperands(
3241 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3242 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3243
3245 if (getWaitOperandsSegments())
3246 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3247
3248 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3249 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3250 getWaitOperandsMutable(), segments));
3251 setWaitOperandsSegments(segments);
3252
3254 if (getHasWaitDevnumAttr())
3255 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3256 hasDevnums.insert(
3257 hasDevnums.end(),
3258 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3259 mlir::BoolAttr::get(context, hasDevnum));
3260 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3261}
3262
3263//===----------------------------------------------------------------------===//
3264// HostDataOp
3265//===----------------------------------------------------------------------===//
3266
3267LogicalResult acc::HostDataOp::verify() {
3268 if (getDataClauseOperands().empty())
3269 return emitError("at least one operand must appear on the host_data "
3270 "operation");
3271
3273 for (mlir::Value operand : getDataClauseOperands()) {
3274 auto useDeviceOp =
3275 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3276 if (!useDeviceOp)
3277 return emitError("expect data entry operation as defining op");
3278
3279 // Check for duplicate use_device clauses
3280 if (!seenVars.insert(useDeviceOp.getVar()).second)
3281 return emitError("duplicate use_device variable");
3282 }
3283 return success();
3284}
3285
3286void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3287 MLIRContext *context) {
3288 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3289}
3290
3291//===----------------------------------------------------------------------===//
3292// LoopOp
3293//===----------------------------------------------------------------------===//
3294
3295static ParseResult parseGangValue(
3296 OpAsmParser &parser, llvm::StringRef keyword,
3299 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
3300 bool &needCommaBetweenValues, bool &newValue) {
3301 if (succeeded(parser.parseOptionalKeyword(keyword))) {
3302 if (parser.parseEqual())
3303 return failure();
3304 if (parser.parseOperand(operands.emplace_back()) ||
3305 parser.parseColonType(types.emplace_back()))
3306 return failure();
3307 attributes.push_back(gangArgType);
3308 needCommaBetweenValues = true;
3309 newValue = true;
3310 }
3311 return success();
3312}
3313
3314static ParseResult parseGangClause(
3315 OpAsmParser &parser,
3317 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3318 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3319 mlir::ArrayAttr &gangOnlyDeviceType) {
3320 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3321 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3322 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3324 bool needCommaBetweenValues = false;
3325 bool needCommaBeforeOperands = false;
3326
3327 if (failed(parser.parseOptionalLParen())) {
3328 // Gang only keyword
3329 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3330 parser.getContext(), mlir::acc::DeviceType::None));
3331 gangOnlyDeviceType =
3332 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3333 return success();
3334 }
3335
3336 // Parse gang only attributes
3337 if (succeeded(parser.parseOptionalLSquare())) {
3338 // Parse gang only attributes
3339 if (failed(parser.parseCommaSeparatedList([&]() {
3340 if (parser.parseAttribute(
3341 gangOnlyDeviceTypeAttributes.emplace_back()))
3342 return failure();
3343 return success();
3344 })))
3345 return failure();
3346 if (parser.parseRSquare())
3347 return failure();
3348 needCommaBeforeOperands = true;
3349 }
3350
3351 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3352 mlir::acc::GangArgType::Num);
3353 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3354 mlir::acc::GangArgType::Dim);
3355 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3356 parser.getContext(), mlir::acc::GangArgType::Static);
3357
3358 do {
3359 if (needCommaBeforeOperands) {
3360 needCommaBeforeOperands = false;
3361 continue;
3362 }
3363
3364 if (failed(parser.parseLBrace()))
3365 return failure();
3366
3367 int32_t crtOperandsSize = gangOperands.size();
3368 while (true) {
3369 bool newValue = false;
3370 bool needValue = false;
3371 if (needCommaBetweenValues) {
3372 if (succeeded(parser.parseOptionalComma()))
3373 needValue = true; // expect a new value after comma.
3374 else
3375 break;
3376 }
3377
3378 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3379 gangOperands, gangOperandsType,
3380 gangArgTypeAttributes, argNum,
3381 needCommaBetweenValues, newValue)))
3382 return failure();
3383 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3384 gangOperands, gangOperandsType,
3385 gangArgTypeAttributes, argDim,
3386 needCommaBetweenValues, newValue)))
3387 return failure();
3388 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3389 gangOperands, gangOperandsType,
3390 gangArgTypeAttributes, argStatic,
3391 needCommaBetweenValues, newValue)))
3392 return failure();
3393
3394 if (!newValue && needValue) {
3395 parser.emitError(parser.getCurrentLocation(),
3396 "new value expected after comma");
3397 return failure();
3398 }
3399
3400 if (!newValue)
3401 break;
3402 }
3403
3404 if (gangOperands.empty())
3405 return parser.emitError(
3406 parser.getCurrentLocation(),
3407 "expect at least one of num, dim or static values");
3408
3409 if (failed(parser.parseRBrace()))
3410 return failure();
3411
3412 if (succeeded(parser.parseOptionalLSquare())) {
3413 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3414 parser.parseRSquare())
3415 return failure();
3416 } else {
3417 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3418 parser.getContext(), mlir::acc::DeviceType::None));
3419 }
3420
3421 seg.push_back(gangOperands.size() - crtOperandsSize);
3422
3423 } while (succeeded(parser.parseOptionalComma()));
3424
3425 if (failed(parser.parseRParen()))
3426 return failure();
3427
3428 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3429 gangArgTypeAttributes.end());
3430 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3431 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3432
3434 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3435 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3436
3437 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3438 return success();
3439}
3440
3442 mlir::OperandRange operands, mlir::TypeRange types,
3443 std::optional<mlir::ArrayAttr> gangArgTypes,
3444 std::optional<mlir::ArrayAttr> deviceTypes,
3445 std::optional<mlir::DenseI32ArrayAttr> segments,
3446 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3447
3448 if (operands.begin() == operands.end() &&
3449 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3450 return;
3451 }
3452
3453 p << "(";
3454
3455 printDeviceTypes(p, gangOnlyDeviceTypes);
3456
3457 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3458 hasDeviceTypeValues(deviceTypes))
3459 p << ", ";
3460
3461 if (hasDeviceTypeValues(deviceTypes)) {
3462 unsigned opIdx = 0;
3463 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3464 p << "{";
3465 llvm::interleaveComma(
3466 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3467 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3468 (*gangArgTypes)[opIdx]);
3469 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3470 p << LoopOp::getGangNumKeyword();
3471 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3472 p << LoopOp::getGangDimKeyword();
3473 else if (gangArgTypeAttr.getValue() ==
3474 mlir::acc::GangArgType::Static)
3475 p << LoopOp::getGangStaticKeyword();
3476 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3477 ++opIdx;
3478 });
3479 p << "}";
3480 printSingleDeviceType(p, it.value());
3481 });
3482 }
3483 p << ")";
3484}
3485
3487 std::optional<mlir::ArrayAttr> segments,
3488 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3489 if (!segments)
3490 return false;
3491 for (auto attr : *segments) {
3492 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3493 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3494 return true;
3495 }
3496 return false;
3497}
3498
3499/// Check for duplicates in the DeviceType array attribute.
3500/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3501static std::optional<mlir::acc::DeviceType>
3502checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3503 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3504 if (!deviceTypes)
3505 return std::nullopt;
3506 for (auto attr : deviceTypes) {
3507 auto deviceTypeAttr =
3508 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3509 if (!deviceTypeAttr)
3510 return mlir::acc::DeviceType::None;
3511 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3512 return deviceTypeAttr.getValue();
3513 }
3514 return std::nullopt;
3515}
3516
3517LogicalResult acc::LoopOp::verify() {
3518 if (getUpperbound().size() != getStep().size())
3519 return emitError() << "number of upperbounds expected to be the same as "
3520 "number of steps";
3521
3522 if (getUpperbound().size() != getLowerbound().size())
3523 return emitError() << "number of upperbounds expected to be the same as "
3524 "number of lowerbounds";
3525
3526 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3527 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3528 return emitError() << "inclusiveUpperbound size is expected to be the same"
3529 << " as upperbound size";
3530
3531 // Check collapse
3532 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3533 return emitOpError() << "collapse device_type attr must be define when"
3534 << " collapse attr is present";
3535
3536 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3537 getCollapseAttr().getValue().size() !=
3538 getCollapseDeviceTypeAttr().getValue().size())
3539 return emitOpError() << "collapse attribute count must match collapse"
3540 << " device_type count";
3541 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3542 return emitOpError() << "duplicate device_type `"
3543 << acc::stringifyDeviceType(*duplicateDeviceType)
3544 << "` found in collapseDeviceType attribute";
3545
3546 // Check gang
3547 if (!getGangOperands().empty()) {
3548 if (!getGangOperandsArgType())
3549 return emitOpError() << "gangOperandsArgType attribute must be defined"
3550 << " when gang operands are present";
3551
3552 if (getGangOperands().size() !=
3553 getGangOperandsArgTypeAttr().getValue().size())
3554 return emitOpError() << "gangOperandsArgType attribute count must match"
3555 << " gangOperands count";
3556 }
3557 if (getGangAttr()) {
3558 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3559 return emitOpError() << "duplicate device_type `"
3560 << acc::stringifyDeviceType(*duplicateDeviceType)
3561 << "` found in gang attribute";
3562 }
3563
3565 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3566 getGangOperandsDeviceTypeAttr(), "gang")))
3567 return failure();
3568
3569 // Check worker
3570 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3571 return emitOpError() << "duplicate device_type `"
3572 << acc::stringifyDeviceType(*duplicateDeviceType)
3573 << "` found in worker attribute";
3574 if (auto duplicateDeviceType =
3575 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3576 return emitOpError() << "duplicate device_type `"
3577 << acc::stringifyDeviceType(*duplicateDeviceType)
3578 << "` found in workerNumOperandsDeviceType attribute";
3579 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3580 getWorkerNumOperandsDeviceTypeAttr(),
3581 "worker")))
3582 return failure();
3583
3584 // Check vector
3585 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3586 return emitOpError() << "duplicate device_type `"
3587 << acc::stringifyDeviceType(*duplicateDeviceType)
3588 << "` found in vector attribute";
3589 if (auto duplicateDeviceType =
3590 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3591 return emitOpError() << "duplicate device_type `"
3592 << acc::stringifyDeviceType(*duplicateDeviceType)
3593 << "` found in vectorOperandsDeviceType attribute";
3594 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3595 getVectorOperandsDeviceTypeAttr(),
3596 "vector")))
3597 return failure();
3598
3600 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3601 getTileOperandsDeviceTypeAttr(), "tile")))
3602 return failure();
3603
3604 // auto, independent and seq attribute are mutually exclusive.
3605 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3606 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3607 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3608 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3609 return emitError() << "only one of auto, independent, seq can be present "
3610 "at the same time";
3611 }
3612
3613 // Check that at least one of auto, independent, or seq is present
3614 // for the device-independent default clauses.
3615 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3616 return attr.getValue() == mlir::acc::DeviceType::None;
3617 };
3618 bool hasDefaultSeq =
3619 getSeqAttr()
3620 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3621 hasDeviceNone)
3622 : false;
3623 bool hasDefaultIndependent =
3624 getIndependentAttr()
3625 ? llvm::any_of(
3626 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3627 hasDeviceNone)
3628 : false;
3629 bool hasDefaultAuto =
3630 getAuto_Attr()
3631 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3632 hasDeviceNone)
3633 : false;
3634 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3635 return emitError()
3636 << "at least one of auto, independent, seq must be present";
3637 }
3638
3639 // Gang, worker and vector are incompatible with seq.
3640 if (getSeqAttr()) {
3641 for (auto attr : getSeqAttr()) {
3642 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3643 if (hasVector(deviceTypeAttr.getValue()) ||
3644 getVectorValue(deviceTypeAttr.getValue()) ||
3645 hasWorker(deviceTypeAttr.getValue()) ||
3646 getWorkerValue(deviceTypeAttr.getValue()) ||
3647 hasGang(deviceTypeAttr.getValue()) ||
3648 getGangValue(mlir::acc::GangArgType::Num,
3649 deviceTypeAttr.getValue()) ||
3650 getGangValue(mlir::acc::GangArgType::Dim,
3651 deviceTypeAttr.getValue()) ||
3652 getGangValue(mlir::acc::GangArgType::Static,
3653 deviceTypeAttr.getValue()))
3654 return emitError() << "gang, worker or vector cannot appear with seq";
3655 }
3656 }
3657
3658 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3659 mlir::acc::PrivateRecipeOp>(
3660 *this, getPrivateOperands(), "private")))
3661 return failure();
3662
3663 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3664 mlir::acc::FirstprivateRecipeOp>(
3665 *this, getFirstprivateOperands(), "firstprivate")))
3666 return failure();
3667
3668 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3669 mlir::acc::ReductionRecipeOp>(
3670 *this, getReductionOperands(), "reduction")))
3671 return failure();
3672
3673 if (getCombined().has_value() &&
3674 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3675 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3676 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3677 return emitError("unexpected combined constructs attribute");
3678 }
3679
3680 // Check non-empty body().
3681 if (getRegion().empty())
3682 return emitError("expected non-empty body.");
3683
3684 if (getUnstructured()) {
3685 if (!isContainerLike())
3686 return emitError(
3687 "unstructured acc.loop must not have induction variables");
3688 } else if (isContainerLike()) {
3689 // When it is container-like - it is expected to hold a loop-like operation.
3690 // Obtain the maximum collapse count - we use this to check that there
3691 // are enough loops contained.
3692 uint64_t collapseCount = getCollapseValue().value_or(1);
3693 if (getCollapseAttr()) {
3694 for (auto collapseEntry : getCollapseAttr()) {
3695 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3696 if (intAttr.getValue().getZExtValue() > collapseCount)
3697 collapseCount = intAttr.getValue().getZExtValue();
3698 }
3699 }
3700
3701 // We want to check that we find enough loop-like operations inside.
3702 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3703 // level.
3704 mlir::Operation *expectedParent = this->getOperation();
3705 bool foundSibling = false;
3706 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3707 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3708 // This effectively checks that we are not looking at a sibling loop.
3709 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3710 expectedParent) {
3711 foundSibling = true;
3713 }
3714
3715 collapseCount--;
3716 expectedParent = op;
3717 }
3718 // We found enough contained loops.
3719 if (collapseCount == 0)
3722 });
3723
3724 if (foundSibling)
3725 return emitError("found sibling loops inside container-like acc.loop");
3726 if (collapseCount != 0)
3727 return emitError("failed to find enough loop-like operations inside "
3728 "container-like acc.loop");
3729 }
3730
3731 return success();
3732}
3733
3734unsigned LoopOp::getNumDataOperands() {
3735 return getReductionOperands().size() + getPrivateOperands().size() +
3736 getFirstprivateOperands().size();
3737}
3738
3739Value LoopOp::getDataOperand(unsigned i) {
3740 unsigned numOptional =
3741 getLowerbound().size() + getUpperbound().size() + getStep().size();
3742 numOptional += getGangOperands().size();
3743 numOptional += getVectorOperands().size();
3744 numOptional += getWorkerNumOperands().size();
3745 numOptional += getTileOperands().size();
3746 numOptional += getCacheOperands().size();
3747 return getOperand(numOptional + i);
3748}
3749
3750bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3751
3752bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3753 return hasDeviceType(getAuto_(), deviceType);
3754}
3755
3756bool LoopOp::hasIndependent() {
3757 return hasIndependent(mlir::acc::DeviceType::None);
3758}
3759
3760bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3761 return hasDeviceType(getIndependent(), deviceType);
3762}
3763
3764bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3765
3766bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3767 return hasDeviceType(getSeq(), deviceType);
3768}
3769
3770mlir::Value LoopOp::getVectorValue() {
3771 return getVectorValue(mlir::acc::DeviceType::None);
3772}
3773
3774mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3775 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3776 getVectorOperands(), deviceType);
3777}
3778
3779bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3780
3781bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3782 return hasDeviceType(getVector(), deviceType);
3783}
3784
3785mlir::Value LoopOp::getWorkerValue() {
3786 return getWorkerValue(mlir::acc::DeviceType::None);
3787}
3788
3789mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3790 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3791 getWorkerNumOperands(), deviceType);
3792}
3793
3794bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3795
3796bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3797 return hasDeviceType(getWorker(), deviceType);
3798}
3799
3800mlir::Operation::operand_range LoopOp::getTileValues() {
3801 return getTileValues(mlir::acc::DeviceType::None);
3802}
3803
3805LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
// Returns the tile operands of the segment recorded for `deviceType`; the
// segment sizes come from the tile operand segments attribute.
3806 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3807 getTileOperandsSegments(), deviceType);
3808}
3809
3810std::optional<int64_t> LoopOp::getCollapseValue() {
3811 return getCollapseValue(mlir::acc::DeviceType::None);
3812}
3813
3814std::optional<int64_t>
3815LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3816 if (!getCollapseAttr())
3817 return std::nullopt;
3818 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3819 auto intAttr =
3820 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3821 return intAttr.getValue().getZExtValue();
3822 }
3823 return std::nullopt;
3824}
3825
3826mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3827 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3828}
3829
/// Returns the gang operand of kind `gangArgType` (e.g. Num, Dim or Static)
/// that was given for `deviceType`, or a null Value when there is none.
3830mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3831 mlir::acc::DeviceType deviceType) {
3832 if (getGangOperands().empty())
3833 return {};
// Gang operands are stored in per-device-type segments: find the segment for
// `deviceType` and count how many gang operands precede it.
// NOTE(review): the optional device-type/segment attributes are dereferenced
// unconditionally; this relies on them being present whenever gang operands
// exist.
3834 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3835 int32_t nbOperandsBefore = 0;
3836 for (unsigned i = 0; i < *pos; ++i)
3837 nbOperandsBefore += (*getGangOperandsSegments())[i];
3839 getGangOperands()
3840 .drop_front(nbOperandsBefore)
3841 .take_front((*getGangOperandsSegments())[*pos]);
3842
// The gang argument-type array is walked in lockstep with the operands,
// starting at the same offset, to find the operand of the requested kind.
// NOTE(review): the dyn_cast result below is used without a null check;
// mlir::cast would state that expectation explicitly.
3843 int32_t argTypeIdx = nbOperandsBefore;
3844 for (auto value : values) {
3845 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3846 (*getGangOperandsArgType())[argTypeIdx]);
3847 if (gangArgTypeAttr.getValue() == gangArgType)
3848 return value;
3849 ++argTypeIdx;
3850 }
3851 }
3852 return {};
3853}
3854
3855bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3856
3857bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3858 return hasDeviceType(getGang(), deviceType);
3859}
3860
3861llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3862 return {&getRegion()};
3863}
3864
3865/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3866/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3867/// `(` ssa-id-and-type-list `)`
3868/// region
///
/// The control header is optional. When the `control` keyword is absent, only
/// the region is parsed and nothing is added to the bound/step lists. When it
/// is present, the induction variables are parsed with their types. Each of
/// the lower bound, upper bound and step lists is then parsed with exactly as
/// many operands as there are induction variables.
3869ParseResult
3872 SmallVectorImpl<Type> &lowerboundType,
3874 SmallVectorImpl<Type> &upperboundType,
3876 SmallVectorImpl<Type> &stepType) {
// Any parse failure inside the optional header aborts the whole parse.
3877
3879 if (succeeded(
3880 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3881 if (parser.parseLParen() ||
3882 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3883 /*allowType=*/true) ||
3884 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3885 parser.parseOperandList(lowerbound, inductionVars.size(),
3887 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3888 parser.parseKeyword("to") || parser.parseLParen() ||
3889 parser.parseOperandList(upperbound, inductionVars.size(),
3891 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3892 parser.parseKeyword("step") || parser.parseLParen() ||
3893 parser.parseOperandList(step, inductionVars.size(),
3895 parser.parseColonTypeList(stepType) || parser.parseRParen())
3896 return failure();
3897 }
// The parsed induction variables (possibly none) are handed to parseRegion
// as the region's entry block arguments.
3898 return parser.parseRegion(region, inductionVars);
3899}
3900
3902 ValueRange lowerbound, TypeRange lowerboundType,
3903 ValueRange upperbound, TypeRange upperboundType,
3904 ValueRange steps, TypeRange stepType) {
// Printer counterpart of the loop-control parser. The `control(...)` header is
// printed only when the region's entry block has arguments (the induction
// variables). The region is then printed without its entry block arguments,
// since those already appear in the header.
// NOTE(review): region.front() requires a non-empty region, which the
// LoopOp verifier enforces. The `") " << " step ("` sequence emits two
// spaces before `step`.
3905 ValueRange regionArgs = region.front().getArguments();
3906 if (!regionArgs.empty()) {
3907 p << acc::LoopOp::getControlKeyword() << "(";
3908 llvm::interleaveComma(regionArgs, p,
3909 [&p](Value v) { p << v << " : " << v.getType(); });
3910 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3911 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3912 << " : " << stepType << ") ";
3913 }
3914 p.printRegion(region, /*printEntryBlockArgs=*/false);
3915}
3916
3917void acc::LoopOp::addSeq(MLIRContext *context,
3918 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3919 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3920 effectiveDeviceTypes));
3921}
3922
3923void acc::LoopOp::addIndependent(
3924 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3925 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3926 context, getIndependentAttr(), effectiveDeviceTypes));
3927}
3928
3929void acc::LoopOp::addAuto(MLIRContext *context,
3930 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3931 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3932 effectiveDeviceTypes));
3933}
3934
/// Records a collapse count of `value` for each device type in
/// `effectiveDeviceTypes`, or for DeviceType::None when that list is empty.
/// Existing collapse entries are kept, and the new entries are appended to the
/// parallel `collapse` / `collapseDeviceType` arrays. No check is made here
/// for an existing entry with the same device type.
3935void acc::LoopOp::setCollapseForDeviceTypes(
3936 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3937 llvm::APInt value) {
3940
// The two arrays are parallel, so either both are set or neither is. The
// count is stored as a 64-bit IntegerAttr, so it must already be 64 bits wide.
3941 assert((getCollapseAttr() == nullptr) ==
3942 (getCollapseDeviceTypeAttr() == nullptr));
3943 assert(value.getBitWidth() == 64);
3944
// Start from a copy of the existing (value, device type) pairs.
3945 if (getCollapseAttr()) {
3946 for (const auto &existing :
3947 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3948 newValues.push_back(std::get<0>(existing));
3949 newDeviceTypes.push_back(std::get<1>(existing));
3950 }
3951 }
3952
3953 if (effectiveDeviceTypes.empty()) {
3954 // If the effective device-types list is empty, this is before there are any
3955 // being applied by device_type, so this should be added as a 'none'.
3956 newValues.push_back(
3957 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3958 newDeviceTypes.push_back(
3959 acc::DeviceTypeAttr::get(context, DeviceType::None));
3960 } else {
3961 for (DeviceType dt : effectiveDeviceTypes) {
3962 newValues.push_back(
3963 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3964 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3965 }
3966 }
3967
// Replace both arrays at once so they stay the same length.
3968 setCollapseAttr(ArrayAttr::get(context, newValues));
3969 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3970}
3971
/// Adds `values` as tile operands for the device types in
/// `effectiveDeviceTypes`. Updates the tile device-type array and the
/// per-device-type tile segment sizes.
3972void acc::LoopOp::setTileForDeviceTypes(
3973 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3974 ValueRange values) {
// Seed `segments` with the existing segment sizes (if any). The helper is
// given `segments` together with the mutable tile operand list and presumably
// extends both; the resulting sizes are then written back.
3976 if (getTileOperandsSegments())
3977 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3978
3979 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3980 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3981 getTileOperandsMutable(), segments));
3982
3983 setTileOperandsSegments(segments);
3984}
3985
3986void acc::LoopOp::addVectorOperand(
3987 MLIRContext *context, mlir::Value newValue,
3988 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3989 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3990 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3991 newValue, getVectorOperandsMutable()));
3992}
3993
3994void acc::LoopOp::addEmptyVector(
3995 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3996 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3997 effectiveDeviceTypes));
3998}
3999
4000void acc::LoopOp::addWorkerNumOperand(
4001 MLIRContext *context, mlir::Value newValue,
4002 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4003 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4004 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
4005 newValue, getWorkerNumOperandsMutable()));
4006}
4007
4008void acc::LoopOp::addEmptyWorker(
4009 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4010 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4011 effectiveDeviceTypes));
4012}
4013
4014void acc::LoopOp::addEmptyGang(
4015 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4016 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4017 effectiveDeviceTypes));
4018}
4019
4020bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
4021 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
4022 return attr.getValue() == dt;
4023 };
4024 auto testFromArr = [=](ArrayAttr arr) -> bool {
4025 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
4026 };
4027
4028 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
4029 return true;
4030 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
4031 return true;
4032 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
4033 return true;
4034
4035 return false;
4036}
4037
4038bool acc::LoopOp::hasDefaultGangWorkerVector() {
4039 return hasAnyGangWorkerVector(DeviceType::None);
4040}
4041
4042bool acc::LoopOp::hasAnyGangWorkerVector(DeviceType deviceType) {
4043 return hasVector(deviceType) || getVectorValue(deviceType) ||
4044 hasWorker(deviceType) || getWorkerValue(deviceType) ||
4045 hasGang(deviceType) || getGangValue(GangArgType::Num, deviceType) ||
4046 getGangValue(GangArgType::Dim, deviceType) ||
4047 getGangValue(GangArgType::Static, deviceType);
4048}
4049
4050acc::LoopParMode
4051acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
4052 if (hasSeq(deviceType))
4053 return LoopParMode::loop_seq;
4054 if (hasAuto(deviceType))
4055 return LoopParMode::loop_auto;
4056 if (hasIndependent(deviceType))
4057 return LoopParMode::loop_independent;
4058 if (hasSeq())
4059 return LoopParMode::loop_seq;
4060 if (hasAuto())
4061 return LoopParMode::loop_auto;
4062 assert(hasIndependent() &&
4063 "loop must have default auto, seq, or independent");
4064 return LoopParMode::loop_independent;
4065}
4066
/// Adds gang operands `values`, whose kinds are given by `argTypes`, for the
/// device types in `effectiveDeviceTypes`. In addition to the operands, the
/// device-type array and the segment sizes, this keeps the gang argument-type
/// array in sync by appending a copy of `argTypes` for every segment that was
/// added.
4067void acc::LoopOp::addGangOperands(
4068 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4071 if (std::optional<ArrayRef<int32_t>> existingSegments =
4072 getGangOperandsSegments())
4073 llvm::copy(*existingSegments, std::back_inserter(segments));
// Remember the segment count so the number of segments the helper adds can be
// computed afterwards.
4074
4075 unsigned beforeCount = segments.size();
4076
4077 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4078 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
4079 getGangOperandsMutable(), segments));
4080
4081 setGangOperandsSegments(segments);
4082
4083 // This is a bit of extra work to make sure we update the 'types' correctly by
4084 // adding to the types collection the correct number of times. We could
4085 // potentially add something similar to the
4086 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
4087 // excessive for a one-off case.
4088 unsigned numAdded = segments.size() - beforeCount;
4089
4090 if (numAdded > 0) {
4092 if (getGangOperandsArgTypeAttr())
4093 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
4094
// Append one copy of `argTypes` per added segment. NOTE(review): this assumes
// `argTypes` has one entry per value in `values`, so the argument-type array
// stays index-aligned with the gang operands. `i` only drives the repetition;
// `(void)i` silences the unused-variable warning.
4095 for (auto i : llvm::index_range(0u, numAdded)) {
4096 llvm::transform(argTypes, std::back_inserter(gangTypes),
4097 [=](mlir::acc::GangArgType gangTy) {
4098 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
4099 });
4100 (void)i;
4101 }
4102
4103 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
4104 }
4105}
4106
4107void acc::LoopOp::addPrivatization(MLIRContext *context,
4108 mlir::acc::PrivateOp op,
4109 mlir::acc::PrivateRecipeOp recipe) {
4110 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4111 getPrivateOperandsMutable().append(op.getResult());
4112}
4113
4114void acc::LoopOp::addFirstPrivatization(
4115 MLIRContext *context, mlir::acc::FirstprivateOp op,
4116 mlir::acc::FirstprivateRecipeOp recipe) {
4117 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4118 getFirstprivateOperandsMutable().append(op.getResult());
4119}
4120
4121void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
4122 mlir::acc::ReductionRecipeOp recipe) {
4123 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4124 getReductionOperandsMutable().append(op.getResult());
4125}
4126
4127//===----------------------------------------------------------------------===//
4128// DataOp
4129//===----------------------------------------------------------------------===//
4130
/// Verifies acc.data. At least one operand or the default attribute must be
/// present, and every data clause operand must be produced by one of the data
/// entry/exit operations or acc.getdeviceptr listed below.
4131LogicalResult acc::DataOp::verify() {
4132 // 2.6.5. Data Construct restriction
4133 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
4134 // attach, or default clause must appear on a data construct.
4135 if (getOperands().empty() && !getDefaultAttr())
4136 return emitError("at least one operand or the default attribute "
4137 "must appear on the data operation");
4138
// A block argument has no defining op (getDefiningOp() returns null), so it is
// rejected before isa<> is applied to the defining op.
4139 for (mlir::Value operand : getDataClauseOperands())
4140 if (isa<BlockArgument>(operand) ||
4141 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4142 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
4143 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
4144 operand.getDefiningOp()))
4145 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4146 "as defining op");
4147
4149 return failure();
4150
4151 return success();
4152}
4153
4154unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
4155
4156Value DataOp::getDataOperand(unsigned i) {
4157 unsigned numOptional = getIfCond() ? 1 : 0;
4158 numOptional += getAsyncOperands().size() ? 1 : 0;
4159 numOptional += getWaitOperands().size();
4160 return getOperand(numOptional + i);
4161}
4162
4163bool acc::DataOp::hasAsyncOnly() {
4164 return hasAsyncOnly(mlir::acc::DeviceType::None);
4165}
4166
4167bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4168 return hasDeviceType(getAsyncOnly(), deviceType);
4169}
4170
4171mlir::Value DataOp::getAsyncValue() {
4172 return getAsyncValue(mlir::acc::DeviceType::None);
4173}
4174
/// Returns the async operand given for `deviceType`, looked up in the async
/// operand segment of that device type (presumably a null Value when there is
/// none; confirm against the segment lookup helper).
4175mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4177 getAsyncOperands(), deviceType);
4178}
4179
4180bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4181
4182bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4183 return hasDeviceType(getWaitOnly(), deviceType);
4184}
4185
4186mlir::Operation::operand_range DataOp::getWaitValues() {
4187 return getWaitValues(mlir::acc::DeviceType::None);
4188}
4189
/// Returns the wait values recorded for `deviceType`. The hasWaitDevnum flags
/// are passed to the lookup as well; presumably the devnum entry of the
/// segment is excluded here and exposed through getWaitDevnum instead
/// (confirm against the helper).
4191DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4193 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4194 getHasWaitDevnum(), deviceType);
4195}
4196
4197mlir::Value DataOp::getWaitDevnum() {
4198 return getWaitDevnum(mlir::acc::DeviceType::None);
4199}
4200
4201mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4202 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4203 getWaitOperandsSegments(), getHasWaitDevnum(),
4204 deviceType);
4205}
4206
4207void acc::DataOp::addAsyncOnly(
4208 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4209 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4210 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4211}
4212
4213void acc::DataOp::addAsyncOperand(
4214 MLIRContext *context, mlir::Value newValue,
4215 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4216 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4217 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4218 getAsyncOperandsMutable()));
4219}
4220
4221void acc::DataOp::addWaitOnly(MLIRContext *context,
4222 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4223 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4224 effectiveDeviceTypes));
4225}
4226
/// Adds the wait operands `newValues` for the device types in
/// `effectiveDeviceTypes`. Updates the wait device-type array and segment
/// sizes, and appends `hasDevnum` to the hasWaitDevnum array once per affected
/// device type (once when the list is empty).
4227void acc::DataOp::addWaitOperands(
4228 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4229 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
// Seed `segments` with the existing segment sizes before the helper extends
// them; the result is written back afterwards.
4230
4232 if (getWaitOperandsSegments())
4233 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4234
4235 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4236 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4237 getWaitOperandsMutable(), segments));
4238 setWaitOperandsSegments(segments);
// Append one `hasDevnum` flag per affected device type; std::max(..., 1)
// covers the empty device-type list, which still adds one entry.
4239
4241 if (getHasWaitDevnumAttr())
4242 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4243 hasDevnums.insert(
4244 hasDevnums.end(),
4245 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4246 mlir::BoolAttr::get(context, hasDevnum));
4247 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4248}
4249
4250//===----------------------------------------------------------------------===//
4251// ExitDataOp
4252//===----------------------------------------------------------------------===//
4253
4254LogicalResult acc::ExitDataOp::verify() {
4255 // 2.6.6. Data Exit Directive restriction
4256 // At least one copyout, delete, or detach clause must appear on an exit data
4257 // directive.
4258 if (getDataClauseOperands().empty())
4259 return emitError("at least one operand must be present in dataOperands on "
4260 "the exit data operation");
4261
4262 // The async attribute represent the async clause without value. Therefore the
4263 // attribute and operand cannot appear at the same time.
4264 if (getAsyncOperand() && getAsync())
4265 return emitError("async attribute cannot appear with asyncOperand");
4266
4267 // The wait attribute represent the wait clause without values. Therefore the
4268 // attribute and operands cannot appear at the same time.
4269 if (!getWaitOperands().empty() && getWait())
4270 return emitError("wait attribute cannot appear with waitOperands");
4271
4272 if (getWaitDevnum() && getWaitOperands().empty())
4273 return emitError("wait_devnum cannot appear without waitOperands");
4274
4275 return success();
4276}
4277
4278unsigned ExitDataOp::getNumDataOperands() {
4279 return getDataClauseOperands().size();
4280}
4281
4282Value ExitDataOp::getDataOperand(unsigned i) {
4283 unsigned numOptional = getIfCond() ? 1 : 0;
4284 numOptional += getAsyncOperand() ? 1 : 0;
4285 numOptional += getWaitDevnum() ? 1 : 0;
4286 return getOperand(getWaitOperands().size() + numOptional + i);
4287}
4288
4289void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4290 MLIRContext *context) {
4291 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4292}
4293
4294void ExitDataOp::addAsyncOnly(MLIRContext *context,
4295 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4296 assert(effectiveDeviceTypes.empty());
4297 assert(!getAsyncAttr());
4298 assert(!getAsyncOperand());
4299
4300 setAsyncAttr(mlir::UnitAttr::get(context));
4301}
4302
4303void ExitDataOp::addAsyncOperand(
4304 MLIRContext *context, mlir::Value newValue,
4305 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4306 assert(effectiveDeviceTypes.empty());
4307 assert(!getAsyncAttr());
4308 assert(!getAsyncOperand());
4309
4310 getAsyncOperandMutable().append(newValue);
4311}
4312
4313void ExitDataOp::addWaitOnly(MLIRContext *context,
4314 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4315 assert(effectiveDeviceTypes.empty());
4316 assert(!getWaitAttr());
4317 assert(getWaitOperands().empty());
4318 assert(!getWaitDevnum());
4319
4320 setWaitAttr(mlir::UnitAttr::get(context));
4321}
4322
4323void ExitDataOp::addWaitOperands(
4324 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4325 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4326 assert(effectiveDeviceTypes.empty());
4327 assert(!getWaitAttr());
4328 assert(getWaitOperands().empty());
4329 assert(!getWaitDevnum());
4330
4331 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4332 // operands list.
4333 if (hasDevnum) {
4334 getWaitDevnumMutable().append(newValues.front());
4335 newValues = newValues.drop_front();
4336 }
4337
4338 getWaitOperandsMutable().append(newValues);
4339}
4340
4341//===----------------------------------------------------------------------===//
4342// EnterDataOp
4343//===----------------------------------------------------------------------===//
4344
4345LogicalResult acc::EnterDataOp::verify() {
4346 // 2.6.6. Data Enter Directive restriction
4347 // At least one copyin, create, or attach clause must appear on an enter data
4348 // directive.
4349 if (getDataClauseOperands().empty())
4350 return emitError("at least one operand must be present in dataOperands on "
4351 "the enter data operation");
4352
4353 // The async attribute represent the async clause without value. Therefore the
4354 // attribute and operand cannot appear at the same time.
4355 if (getAsyncOperand() && getAsync())
4356 return emitError("async attribute cannot appear with asyncOperand");
4357
4358 // The wait attribute represent the wait clause without values. Therefore the
4359 // attribute and operands cannot appear at the same time.
4360 if (!getWaitOperands().empty() && getWait())
4361 return emitError("wait attribute cannot appear with waitOperands");
4362
4363 if (getWaitDevnum() && getWaitOperands().empty())
4364 return emitError("wait_devnum cannot appear without waitOperands");
4365
4366 for (mlir::Value operand : getDataClauseOperands())
4367 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4368 operand.getDefiningOp()))
4369 return emitError("expect data entry operation as defining op");
4370
4371 return success();
4372}
4373
4374unsigned EnterDataOp::getNumDataOperands() {
4375 return getDataClauseOperands().size();
4376}
4377
4378Value EnterDataOp::getDataOperand(unsigned i) {
4379 unsigned numOptional = getIfCond() ? 1 : 0;
4380 numOptional += getAsyncOperand() ? 1 : 0;
4381 numOptional += getWaitDevnum() ? 1 : 0;
4382 return getOperand(getWaitOperands().size() + numOptional + i);
4383}
4384
4385void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4386 MLIRContext *context) {
4387 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4388}
4389
4390void EnterDataOp::addAsyncOnly(
4391 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4392 assert(effectiveDeviceTypes.empty());
4393 assert(!getAsyncAttr());
4394 assert(!getAsyncOperand());
4395
4396 setAsyncAttr(mlir::UnitAttr::get(context));
4397}
4398
4399void EnterDataOp::addAsyncOperand(
4400 MLIRContext *context, mlir::Value newValue,
4401 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4402 assert(effectiveDeviceTypes.empty());
4403 assert(!getAsyncAttr());
4404 assert(!getAsyncOperand());
4405
4406 getAsyncOperandMutable().append(newValue);
4407}
4408
4409void EnterDataOp::addWaitOnly(MLIRContext *context,
4410 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4411 assert(effectiveDeviceTypes.empty());
4412 assert(!getWaitAttr());
4413 assert(getWaitOperands().empty());
4414 assert(!getWaitDevnum());
4415
4416 setWaitAttr(mlir::UnitAttr::get(context));
4417}
4418
4419void EnterDataOp::addWaitOperands(
4420 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4421 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4422 assert(effectiveDeviceTypes.empty());
4423 assert(!getWaitAttr());
4424 assert(getWaitOperands().empty());
4425 assert(!getWaitDevnum());
4426
4427 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4428 // operands list.
4429 if (hasDevnum) {
4430 getWaitDevnumMutable().append(newValues.front());
4431 newValues = newValues.drop_front();
4432 }
4433
4434 getWaitOperandsMutable().append(newValues);
4435}
4436
4437//===----------------------------------------------------------------------===//
4438// AtomicReadOp
4439//===----------------------------------------------------------------------===//
4440
4441LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4442
4443//===----------------------------------------------------------------------===//
4444// AtomicWriteOp
4445//===----------------------------------------------------------------------===//
4446
4447LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4448
4449//===----------------------------------------------------------------------===//
4450// AtomicUpdateOp
4451//===----------------------------------------------------------------------===//
4452
4453LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4454 PatternRewriter &rewriter) {
4455 if (op.isNoOp()) {
4456 rewriter.eraseOp(op);
4457 return success();
4458 }
4459
4460 if (Value writeVal = op.getWriteOpVal()) {
4461 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4462 op.getIfCond());
4463 return success();
4464 }
4465
4466 return failure();
4467}
4468
4469LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4470
4471LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4472
4473//===----------------------------------------------------------------------===//
4474// AtomicCaptureOp
4475//===----------------------------------------------------------------------===//
4476
4477AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4478 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4479 return op;
4480 return dyn_cast<AtomicReadOp>(getSecondOp());
4481}
4482
4483AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4484 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4485 return op;
4486 return dyn_cast<AtomicWriteOp>(getSecondOp());
4487}
4488
4489AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4490 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4491 return op;
4492 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4493}
4494
4495LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4496
4497//===----------------------------------------------------------------------===//
4498// DeclareEnterOp
4499//===----------------------------------------------------------------------===//
4500
/// Shared verifier for the declare operations. It checks that at least one
/// data operand is present (unless `requireAtLeastOneOperand` is false), and
/// that every operand is defined by one of the declare data entry operations
/// or acc.getdeviceptr listed below.
4501template <typename Op>
4502static LogicalResult
4504 bool requireAtLeastOneOperand = true) {
4505 if (operands.empty() && requireAtLeastOneOperand)
4506 return emitError(
4507 op->getLoc(),
4508 "at least one operand must appear on the declare operation");
4509
4510 for (mlir::Value operand : operands) {
// Block arguments have no defining op, so reject them before isa<> is
// applied to the (otherwise null) defining op.
4511 if (isa<BlockArgument>(operand) ||
4512 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4513 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4514 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4515 operand.getDefiningOp()))
4516 return op.emitError(
4517 "expect valid declare data entry operation or acc.getdeviceptr "
4518 "as defining op");
4519
// Asserts-only sanity checks: the defining op must expose a var and a data
// clause. The (void) casts keep release builds free of unused-variable
// warnings.
4520 mlir::Value var{getVar(operand.getDefiningOp())};
4521 assert(var && "declare operands can only be data entry operations which "
4522 "must have var");
4523 (void)var;
4524 std::optional<mlir::acc::DataClause> dataClauseOptional{
4525 getDataClause(operand.getDefiningOp())};
4526 assert(dataClauseOptional.has_value() &&
4527 "declare operands can only be data entry operations which must have "
4528 "dataClause");
4529 (void)dataClauseOptional;
4530 }
4531
4532 return success();
4533}
4534
4535LogicalResult acc::DeclareEnterOp::verify() {
4536 return checkDeclareOperands(*this, this->getDataClauseOperands());
4537}
4538
4539//===----------------------------------------------------------------------===//
4540// DeclareExitOp
4541//===----------------------------------------------------------------------===//
4542
4543LogicalResult acc::DeclareExitOp::verify() {
4544 if (getToken())
4545 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4546 /*requireAtLeastOneOperand=*/false);
4547 return checkDeclareOperands(*this, this->getDataClauseOperands());
4548}
4549
4550//===----------------------------------------------------------------------===//
4551// DeclareOp
4552//===----------------------------------------------------------------------===//
4553
4554LogicalResult acc::DeclareOp::verify() {
4555 return checkDeclareOperands(*this, this->getDataClauseOperands());
4556}
4557
4558//===----------------------------------------------------------------------===//
4559// RoutineOp
4560//===----------------------------------------------------------------------===//
4561
4562static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4563 acc::DeviceType dtype) {
4564 unsigned parallelism = 0;
4565 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4566 parallelism += op.hasWorker(dtype) ? 1 : 0;
4567 parallelism += op.hasVector(dtype) ? 1 : 0;
4568 parallelism += op.hasSeq(dtype) ? 1 : 0;
4569 return parallelism;
4570}
4571
4572LogicalResult acc::RoutineOp::verify() {
4573 unsigned baseParallelism =
4574 getParallelismForDeviceType(*this, acc::DeviceType::None);
4575
4576 if (baseParallelism > 1)
4577 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4578 "be present at the same time";
4579
4580 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4581 ++dtypeInt) {
4582 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4583 if (dtype == acc::DeviceType::None)
4584 continue;
4585 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4586
4587 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4588 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4589 "be present at the same time for device_type `"
4590 << acc::stringifyDeviceType(dtype) << "`";
4591 }
4592
4593 return success();
4594}
4595
4596static ParseResult parseBindName(OpAsmParser &parser,
4597 mlir::ArrayAttr &bindIdName,
4598 mlir::ArrayAttr &bindStrName,
4599 mlir::ArrayAttr &deviceIdTypes,
4600 mlir::ArrayAttr &deviceStrTypes) {
4601 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4602 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4603 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4604 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4605
4606 if (failed(parser.parseCommaSeparatedList([&]() {
4607 mlir::Attribute newAttr;
4608 bool isSymbolRefAttr;
4609 auto parseResult = parser.parseAttribute(newAttr);
4610 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4611 bindIdNameAttrs.push_back(symbolRefAttr);
4612 isSymbolRefAttr = true;
4613 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4614 bindStrNameAttrs.push_back(stringAttr);
4615 isSymbolRefAttr = false;
4616 }
4617 if (parseResult)
4618 return failure();
4619 if (failed(parser.parseOptionalLSquare())) {
4620 if (isSymbolRefAttr) {
4621 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4622 parser.getContext(), mlir::acc::DeviceType::None));
4623 } else {
4624 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4625 parser.getContext(), mlir::acc::DeviceType::None));
4626 }
4627 } else {
4628 if (isSymbolRefAttr) {
4629 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4630 parser.parseRSquare())
4631 return failure();
4632 } else {
4633 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4634 parser.parseRSquare())
4635 return failure();
4636 }
4637 }
4638 return success();
4639 })))
4640 return failure();
4641
4642 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4643 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4644 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4645 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4646
4647 return success();
4648}
4649
std::optional<mlir::ArrayAttr> bindIdName,
                          std::optional<mlir::ArrayAttr> bindStrName,
                          std::optional<mlir::ArrayAttr> deviceIdTypes,
                          std::optional<mlir::ArrayAttr> deviceStrTypes) {
  // Prints the `bind` names as a single comma-separated list. All
  // symbol-reference names (with their device types) come first, then all
  // string names. Each entry is the name followed by printSingleDeviceType's
  // rendering of its device type.
  // Create combined vectors for all bind names and device types

  // Append bindIdName and deviceIdTypes
  // NOTE(review): assumes bindIdName is present whenever deviceIdTypes has
  // values and that the two arrays have equal length; llvm::zip below stops
  // at the shorter sequence.
  if (hasDeviceTypeValues(deviceIdTypes)) {
    allBindNames.append(bindIdName->begin(), bindIdName->end());
    allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
  }

  // Append bindStrName and deviceStrTypes
  if (hasDeviceTypeValues(deviceStrTypes)) {
    allBindNames.append(bindStrName->begin(), bindStrName->end());
    allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
  }

  // Print the combined sequence
  if (!allBindNames.empty())
    llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
                          [&](const auto &pair) {
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });
}
4679
4680static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4681 mlir::ArrayAttr &gang,
4682 mlir::ArrayAttr &gangDim,
4683 mlir::ArrayAttr &gangDimDeviceTypes) {
4684
4685 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4686 gangDimDeviceTypeAttrs;
4687 bool needCommaBeforeOperands = false;
4688
4689 // Gang keyword only
4690 if (failed(parser.parseOptionalLParen())) {
4691 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4692 parser.getContext(), mlir::acc::DeviceType::None));
4693 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4694 return success();
4695 }
4696
4697 // Parse keyword only attributes
4698 if (succeeded(parser.parseOptionalLSquare())) {
4699 if (failed(parser.parseCommaSeparatedList([&]() {
4700 if (parser.parseAttribute(gangAttrs.emplace_back()))
4701 return failure();
4702 return success();
4703 })))
4704 return failure();
4705 if (parser.parseRSquare())
4706 return failure();
4707 needCommaBeforeOperands = true;
4708 }
4709
4710 if (needCommaBeforeOperands && failed(parser.parseComma()))
4711 return failure();
4712
4713 if (failed(parser.parseCommaSeparatedList([&]() {
4714 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4715 parser.parseColon() ||
4716 parser.parseAttribute(gangDimAttrs.emplace_back()))
4717 return failure();
4718 if (succeeded(parser.parseOptionalLSquare())) {
4719 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4720 parser.parseRSquare())
4721 return failure();
4722 } else {
4723 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4724 parser.getContext(), mlir::acc::DeviceType::None));
4725 }
4726 return success();
4727 })))
4728 return failure();
4729
4730 if (failed(parser.parseRParen()))
4731 return failure();
4732
4733 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4734 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4735 gangDimDeviceTypes =
4736 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4737
4738 return success();
4739}
4740
std::optional<mlir::ArrayAttr> gang,
                            std::optional<mlir::ArrayAttr> gangDim,
                            std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
  // Prints the payload that follows the routine's `gang` keyword:
  // `(` device-type list, then `, ` if dim entries follow, then
  // `dim: <value>` entries each with its device type, then `)`.

  // A lone device-type-less (`none`) gang entry with no dim values prints
  // nothing, leaving just the bare `gang` keyword; parseRoutineGangClause
  // maps a bare keyword back to a single `none` entry.
  // NOTE(review): the dyn_cast result is dereferenced without a null check;
  // `cast` would state the DeviceTypeAttr invariant explicitly.
  if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
      gang->size() == 1) {
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  p << "(";

  printDeviceTypes(p, gang);

  if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
    p << ", ";

  // gangDim and gangDimDeviceTypes are parallel arrays.
  if (hasDeviceTypeValues(gangDimDeviceTypes))
    llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
                          [&](const auto &pair) {
                            p << acc::RoutineOp::getGangDimKeyword() << ": ";
                            p << std::get<0>(pair);
                            printSingleDeviceType(p, std::get<1>(pair));
                          });

  p << ")";
}
4770
/// Parses an optional device-type payload after a keyword:
///   keyword            -> deviceTypes = [none]
///   keyword([dt, ...]) -> deviceTypes = [dt, ...]
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
                                            mlir::ArrayAttr &deviceTypes) {
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
    return success();
  }

  // Parse device type attributes
  // NOTE(review): if `(` is consumed but no `[` follows (e.g. `keyword()`),
  // the closing `)` is never consumed here and `deviceTypes` becomes an empty
  // array, so the error surfaces later on the dangling `)`.
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(attributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare() || parser.parseRParen())
      return failure();
  }
  deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
  return success();
}
4796
/// Prints the device-type payload after a keyword as `([dt, ...])`. Nothing is
/// printed when there are no device types or when the only one is `none`,
/// which parseDeviceTypeArrayAttr reads back from the bare keyword.
static void
    std::optional<mlir::ArrayAttr> deviceTypes) {

  // NOTE(review): dyn_cast results are dereferenced / printed unchecked here;
  // the array is assumed to hold only DeviceTypeAttr.
  if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
    auto deviceTypeAttr =
        mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
    if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
      return;
  }

  if (!hasDeviceTypeValues(deviceTypes))
    return;

  p << "([";
  llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
    auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    p << dTypeAttr;
  });
  p << "])";
}
4818
4819bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4820
4821bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4822 return hasDeviceType(getWorker(), deviceType);
4823}
4824
4825bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4826
4827bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4828 return hasDeviceType(getVector(), deviceType);
4829}
4830
4831bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4832
4833bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4834 return hasDeviceType(getSeq(), deviceType);
4835}
4836
4837std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4838RoutineOp::getBindNameValue() {
4839 return getBindNameValue(mlir::acc::DeviceType::None);
4840}
4841
4842std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4843RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4844 if (hasDeviceTypeValues(getBindIdNameDeviceType())) {
4845 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4846 auto attr = (*getBindIdName())[*pos];
4847 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4848 assert(symbolRefAttr && "expected SymbolRef");
4849 return symbolRefAttr;
4850 }
4851 }
4852
4853 if (hasDeviceTypeValues(getBindStrNameDeviceType())) {
4854 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4855 auto attr = (*getBindStrName())[*pos];
4856 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4857 assert(stringAttr && "expected String");
4858 return stringAttr;
4859 }
4860 }
4861
4862 return std::nullopt;
4863}
4864
4865bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4866
4867bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4868 return hasDeviceType(getGang(), deviceType);
4869}
4870
4871std::optional<int64_t> RoutineOp::getGangDimValue() {
4872 return getGangDimValue(mlir::acc::DeviceType::None);
4873}
4874
4875std::optional<int64_t>
4876RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4877 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4878 return std::nullopt;
4879 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4880 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4881 return intAttr.getInt();
4882 }
4883 return std::nullopt;
4884}
4885
4886void RoutineOp::addSeq(MLIRContext *context,
4887 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4888 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4889 effectiveDeviceTypes));
4890}
4891
4892void RoutineOp::addVector(MLIRContext *context,
4893 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4894 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4895 effectiveDeviceTypes));
4896}
4897
4898void RoutineOp::addWorker(MLIRContext *context,
4899 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4900 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4901 effectiveDeviceTypes));
4902}
4903
4904void RoutineOp::addGang(MLIRContext *context,
4905 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4906 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4907 effectiveDeviceTypes));
4908}
4909
/// Appends a gang dim value `val` to the routine's gang dim entries.
///
/// The existing (gangDim, gangDimDeviceType) pairs are copied first. Then one
/// new pair is added per entry of `effectiveDeviceTypes`, or a single pair for
/// device type `none` when that list is empty. Both arrays are kept the same
/// length (asserted before and after).
/// NOTE(review): existing entries for the same device type are not replaced;
/// which entry a later lookup picks depends on findSegment -- confirm.
/// NOTE(review): `val` is stored in a 64-bit IntegerAttr, so values above
/// INT64_MAX will not round-trip through getGangDimValue's int64_t.
void RoutineOp::addGang(MLIRContext *context,
                        llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
                        uint64_t val) {

  if (getGangDimAttr())
    llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
  if (getGangDimDeviceTypeAttr())
    llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));

  assert(dimValues.size() == deviceTypes.size());

  if (effectiveDeviceTypes.empty()) {
    dimValues.push_back(
        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
    deviceTypes.push_back(
        acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
  } else {
    for (DeviceType dt : effectiveDeviceTypes) {
      dimValues.push_back(
          mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
      deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
    }
  }
  assert(dimValues.size() == deviceTypes.size());

  setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
  setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
4940
/// Adds the string `bind` name `val` for `effectiveDeviceTypes`.
///
/// The device-type array is extended via addDeviceTypeAffectedOperandHelper.
/// The growth of that array (`after - before`) decides how many copies of
/// `val` are appended, so that names and device types stay parallel.
/// NOTE(review): assumes the helper only appends entries; if it ever returned
/// fewer, `after - before` would wrap around (unsigned) -- confirm.
void RoutineOp::addBindStrName(MLIRContext *context,
                               llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
                               mlir::StringAttr val) {
  unsigned before = getBindStrNameDeviceTypeAttr()
                        ? getBindStrNameDeviceTypeAttr().size()
                        : 0;

  setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
  unsigned after = getBindStrNameDeviceTypeAttr().size();

  if (getBindStrNameAttr())
    llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
  for (unsigned i = 0; i < after - before; ++i)
    vals.push_back(val);

  setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
}
4960
/// Adds the symbol-reference `bind` name `val` for `effectiveDeviceTypes`.
///
/// The device-type array is extended via addDeviceTypeAffectedOperandHelper.
/// The growth of that array (`after - before`) decides how many copies of
/// `val` are appended, so that names and device types stay parallel.
/// NOTE(review): assumes the helper only appends entries; if it ever returned
/// fewer, `after - before` would wrap around (unsigned) -- confirm.
void RoutineOp::addBindIDName(MLIRContext *context,
                              llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
                              mlir::SymbolRefAttr val) {
  unsigned before =
      getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;

  setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
  unsigned after = getBindIdNameDeviceTypeAttr().size();

  if (getBindIdNameAttr())
    llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
  for (unsigned i = 0; i < after - before; ++i)
    vals.push_back(val);

  setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
}
4979
4980//===----------------------------------------------------------------------===//
4981// InitOp
4982//===----------------------------------------------------------------------===//
4983
4984LogicalResult acc::InitOp::verify() {
4985 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4986 return emitOpError("cannot be nested in a compute operation");
4987 return success();
4988}
4989
/// Appends `deviceType` to the op's device-type list, keeping any existing
/// entries. Duplicates are not filtered out.
void acc::InitOp::addDeviceType(MLIRContext *context,
                                mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
4999
5000//===----------------------------------------------------------------------===//
5001// ShutdownOp
5002//===----------------------------------------------------------------------===//
5003
5004LogicalResult acc::ShutdownOp::verify() {
5005 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
5006 return emitOpError("cannot be nested in a compute operation");
5007 return success();
5008}
5009
/// Appends `deviceType` to the op's device-type list, keeping any existing
/// entries. Duplicates are not filtered out.
void acc::ShutdownOp::addDeviceType(MLIRContext *context,
                                    mlir::acc::DeviceType deviceType) {
  if (getDeviceTypesAttr())
    llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));

  deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
  setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
}
5019
5020//===----------------------------------------------------------------------===//
5021// SetOp
5022//===----------------------------------------------------------------------===//
5023
5024LogicalResult acc::SetOp::verify() {
5025 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
5026 return emitOpError("cannot be nested in a compute operation");
5027 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
5028 return emitOpError("at least one default_async, device_num, or device_type "
5029 "operand must appear");
5030 return success();
5031}
5032
5033//===----------------------------------------------------------------------===//
5034// UpdateOp
5035//===----------------------------------------------------------------------===//
5036
/// Verifies UpdateOp, in this order:
/// - at least one data-clause operand is present;
/// - the async and wait operand/device-type bookkeeping is consistent;
/// - every data operand is produced by acc.update_device, acc.update_host or
///   acc.getdeviceptr.
LogicalResult acc::UpdateOp::verify() {
  // At least one of host or device should have a value.
  if (getDataClauseOperands().empty())
    return emitError("at least one value must be present in dataOperands");

                                  getAsyncOperandsDeviceTypeAttr(),
                                  "async")))
    return failure();

      *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
      getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

    return failure();

  // NOTE(review): getDefiningOp() is null for block arguments and mlir::isa
  // asserts on null; checkDeclareOperands guards this case with an
  // isa<BlockArgument> check first -- consider doing the same here.
  for (mlir::Value operand : getDataClauseOperands())
    if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");

  return success();
}
5063
5064unsigned UpdateOp::getNumDataOperands() {
5065 return getDataClauseOperands().size();
5066}
5067
5068Value UpdateOp::getDataOperand(unsigned i) {
5069 unsigned numOptional = getAsyncOperands().size();
5070 numOptional += getIfCond() ? 1 : 0;
5071 return getOperand(getWaitOperands().size() + numOptional + i);
5072}
5073
5074void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
5075 MLIRContext *context) {
5076 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
5077}
5078
5079bool UpdateOp::hasAsyncOnly() {
5080 return hasAsyncOnly(mlir::acc::DeviceType::None);
5081}
5082
5083bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5084 return hasDeviceType(getAsyncOnly(), deviceType);
5085}
5086
5087mlir::Value UpdateOp::getAsyncValue() {
5088 return getAsyncValue(mlir::acc::DeviceType::None);
5089}
5090
/// Returns the async operand at the position whose device type matches
/// `deviceType`, or a null Value when there is none.
mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
    return {};

  if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
    return getAsyncOperands()[*pos];

  return {};
}
5100
5101bool UpdateOp::hasWaitOnly() {
5102 return hasWaitOnly(mlir::acc::DeviceType::None);
5103}
5104
5105bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5106 return hasDeviceType(getWaitOnly(), deviceType);
5107}
5108
5109mlir::Operation::operand_range UpdateOp::getWaitValues() {
5110 return getWaitValues(mlir::acc::DeviceType::None);
5111}
5112
UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Returns the wait operands of the segment for `deviceType`. The helper call
  // is not fully visible here; presumably getWaitValuesWithoutDevnum, which
  // excludes the devnum operand -- confirm.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
5119
5120mlir::Value UpdateOp::getWaitDevnum() {
5121 return getWaitDevnum(mlir::acc::DeviceType::None);
5122}
5123
5124mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5125 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
5126 getWaitOperandsSegments(), getHasWaitDevnum(),
5127 deviceType);
5128}
5129
5130void UpdateOp::addAsyncOnly(MLIRContext *context,
5131 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5132 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5133 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5134}
5135
5136void UpdateOp::addAsyncOperand(
5137 MLIRContext *context, mlir::Value newValue,
5138 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5139 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5140 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5141 getAsyncOperandsMutable()));
5142}
5143
5144void UpdateOp::addWaitOnly(MLIRContext *context,
5145 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5146 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5147 effectiveDeviceTypes));
5148}
5149
/// Adds `newValues` as wait operands for `effectiveDeviceTypes`.
///
/// The helper updates the operand segments and the device-type array. Then
/// one `hasDevnum` flag is appended per effective device type (or a single
/// flag when the list is empty), keeping the hasWaitDevnum array in step.
void UpdateOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // At least one flag: an empty effectiveDeviceTypes still adds one entry.
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
5172
5173//===----------------------------------------------------------------------===//
5174// WaitOp
5175//===----------------------------------------------------------------------===//
5176
5177LogicalResult acc::WaitOp::verify() {
5178 // The async attribute represent the async clause without value. Therefore the
5179 // attribute and operand cannot appear at the same time.
5180 if (getAsyncOperand() && getAsync())
5181 return emitError("async attribute cannot appear with asyncOperand");
5182
5183 if (getWaitDevnum() && getWaitOperands().empty())
5184 return emitError("wait_devnum cannot appear without waitOperands");
5185
5186 return success();
5187}
5188
5189#define GET_OP_CLASSES
5190#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5191
5192#define GET_ATTRDEF_CLASSES
5193#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5194
5195#define GET_TYPEDEF_CLASSES
5196#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5197
5198//===----------------------------------------------------------------------===//
5199// acc dialect utilities
5200//===----------------------------------------------------------------------===//
5201
// getVarPtr: data entry ops and the copyout / update-host exit ops all expose
// getVarPtr(). Any other op reaches the Default case, whose value is not
// visible in this excerpt (presumably a null pointer value -- confirm).
  auto varPtr{llvm::TypeSwitch<mlir::Operation *,
                  accDataClauseOp)
                  .Case<ACC_DATA_ENTRY_OPS>(
                      [&](auto entry) { return entry.getVarPtr(); })
                  .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
                      [&](auto exit) { return exit.getVarPtr(); })
                  .Default([&](mlir::Operation *) {
                  })};
  return varPtr;
}
5216
// getVar: only ACC_DATA_ENTRY_OPS are handled; every other op, including
// exit ops, yields a null Value.
  auto varPtr{
          .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
          .Default([&](mlir::Operation *) { return mlir::Value(); })};
  return varPtr;
}
5224
// getVarType: handles the same ops as getVarPtr (entry ops plus copyout /
// update-host); any other op yields a null Type.
  auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
                   .Case<ACC_DATA_ENTRY_OPS>(
                       [&](auto entry) { return entry.getVarType(); })
                   .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
                       [&](auto exit) { return exit.getVarType(); })
                   .Default([&](mlir::Operation *) { return mlir::Type(); })};
  return varType;
}
5234
// getAccPtr: every data entry and data exit op exposes getAccPtr(). The
// Default case's value is not visible in this excerpt (presumably null).
  auto accPtr{llvm::TypeSwitch<mlir::Operation *,
                  accDataClauseOp)
                  .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
                      [&](auto dataClause) { return dataClause.getAccPtr(); })
                  .Default([&](mlir::Operation *) {
                  })};
  return accPtr;
}
5247
// getAccVar: returns the acc var of a data entry/exit op (the Case line is not
// visible here; presumably the same op set as getAccPtr). Other ops yield a
// null Value.
// NOTE(review): the local is named `accPtr` but holds the acc var.
  auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
                  [&](auto dataClause) { return dataClause.getAccVar(); })
                  .Default([&](mlir::Operation *) { return mlir::Value(); })};
  return accPtr;
}
5255
// getVarPtrPtr: only ACC_DATA_ENTRY_OPS carry a varPtrPtr; other ops yield a
// null Value.
  auto varPtrPtr{
          .Case<ACC_DATA_ENTRY_OPS>(
              [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
          .Default([&](mlir::Operation *) { return mlir::Value(); })};
  return varPtrPtr;
}
5264
// getBounds: copies the bounds operands of a data entry/exit op into a new
// vector. The Default case's value is not visible in this excerpt (presumably
// an empty vector).
          accDataClauseOp)
          .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
                dataClause.getBounds().begin(), dataClause.getBounds().end());
          })
          .Default([&](mlir::Operation *) {
          })};
  return bounds;
}
5279
// getAsyncOperands: copies the async operands of a data entry/exit op. The
// Default case's value is not visible in this excerpt (presumably empty).
             accDataClauseOp)
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
            dataClause.getAsyncOperands().begin(),
            dataClause.getAsyncOperands().end());
      })
      .Default([&](mlir::Operation *) {
      });
}
5293
/// Returns the device-type array paired with the async operands of a data
/// entry/exit op, or a null ArrayAttr for any other op.
mlir::ArrayAttr
      .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
        return dataClause.getAsyncOperandsDeviceTypeAttr();
      })
      .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
}
5302
/// Returns the value-less `async` device-type array of a data clause op, or a
/// null ArrayAttr when the op is not handled (the Case line is not visible
/// here).
mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
          [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
      .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
}
5309
/// Returns the `name` of a data entry op, or std::nullopt for any other op.
std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
  auto name{
          .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
          .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
            return {};
          })};
  return name;
}
5319
/// Returns the data clause recorded on a data entry op, or std::nullopt for
/// any other op.
std::optional<mlir::acc::DataClause>
  auto dataClause{
                     accDataEntryOp)
          .Case<ACC_DATA_ENTRY_OPS>(
              [&](auto entry) { return entry.getDataClause(); })
          .Default([&](mlir::Operation *) { return std::nullopt; })};
  return dataClause;
}
5330
// Returns the `implicit` flag of a data entry op; any other op yields false.
  auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
                    .Case<ACC_DATA_ENTRY_OPS>(
                        [&](auto entry) { return entry.getImplicit(); })
                    .Default([&](mlir::Operation *) { return false; })};
  return implicit;
}
5338
// Returns the data-clause operands of an op that exposes
// getDataClauseOperands() (the Case list is not visible here); other ops
// yield an empty ValueRange.
  auto dataOperands{
          [&](auto entry) { return entry.getDataClauseOperands(); })
          .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
  return dataOperands;
}
5347
// Returns the mutable data-clause operand range of an op that exposes
// getDataClauseOperandsMutable(). The Default case returns nullptr; its
// meaning depends on the TypeSwitch result type, which is not visible in this
// excerpt -- confirm.
  auto dataOperands{
          [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
          .Default([&](mlir::Operation *) { return nullptr; })};
  return dataOperands;
}
5357
/// Returns the recipe symbol of a data entry op, or a null SymbolRefAttr for
/// any other op.
mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
  auto recipe{
          .Case<ACC_DATA_ENTRY_OPS>(
              [&](auto entry) { return entry.getRecipeAttr(); })
          .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
  return recipe;
}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4741
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1566
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3486
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:2072
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4596
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:933
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:706
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2594
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:926
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:859
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:686
static bool hasAnyGangWorkerVectorForDeviceType(std::optional< mlir::ArrayAttr > numGangsDeviceType, mlir::Operation::operand_range numGangs, std::optional< llvm::ArrayRef< int32_t > > numGangsSegments, std::optional< mlir::ArrayAttr > numWorkersDeviceType, mlir::Operation::operand_range numWorkers, std::optional< mlir::ArrayAttr > vectorLengthDeviceType, mlir::Operation::operand_range vectorLength, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2212
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:828
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2605
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2510
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:628
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4798
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3295
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2845
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3502
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4503
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:766
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2799
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:646
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:748
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:782
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1331
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3870
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:2027
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2641
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2155
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1341
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:774
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:837
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:657
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:670
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2380
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then continue after the...
Definition OpenACC.cpp:525
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:813
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3901
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:536
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4771
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4680
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2493
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2668
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2784
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2447
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1325
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2760
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:903
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3314
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1798
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2829
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2424
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:792
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:2041
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2741
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:632
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3441
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2679
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:871
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:726
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:2083
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4562
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2430
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2865
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4650
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
IntegerType getI64Type()
Definition Builders.cpp:69
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:435
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
OperandRange operand_range
Definition Operation.h:396
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
void setAttr(StringAttr name, Attribute value)
If an attribute with the specified name exists, change it to the new value.
Definition Operation.h:607
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
result_range getResults()
Definition Operation.h:440
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
bool isOperation() const
Return true if the successor is an operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:384
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_CONSTRUCT_OPS
Definition OpenACC.h:63
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:74
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:49
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:59
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5248
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5217
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5236
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5321
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5349
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5266
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5339
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5310
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5331
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5358
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5281
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:172
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5256
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:215
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5303
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5225
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5203
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5295
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.