MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
17#include "mlir/IR/BuiltinOps.h"
20#include "mlir/IR/IRMapping.h"
21#include "mlir/IR/Matchers.h"
23#include "mlir/IR/SymbolTable.h"
24#include "mlir/Support/LLVM.h"
26#include "llvm/ADT/SmallSet.h"
27#include "llvm/ADT/TypeSwitch.h"
28#include "llvm/Support/LogicalResult.h"
29#include <variant>
30
31using namespace mlir;
32using namespace acc;
33
34#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
38#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
39
40namespace {
41
42static bool isScalarLikeType(Type type) {
43 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
44}
45
46/// Helper function to attach the `VarName` attribute to an operation
47/// if a variable name is provided.
48static void attachVarNameAttr(Operation *op, OpBuilder &builder,
49 StringRef varName) {
50 if (!varName.empty()) {
51 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
52 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
53 }
54}
55
56template <typename T>
57struct MemRefPointerLikeModel
58 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
59 Type getElementType(Type pointer) const {
60 return cast<T>(pointer).getElementType();
61 }
62
63 mlir::acc::VariableTypeCategory
64 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
65 Type varType) const {
66 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
67 return mappableTy.getTypeCategory(varPtr);
68 }
69 auto memrefTy = cast<T>(pointer);
70 if (!memrefTy.hasRank()) {
71 // This memref is unranked - aka it could have any rank, including a
72 // rank of 0 which could mean scalar. For now, return uncategorized.
73 return mlir::acc::VariableTypeCategory::uncategorized;
74 }
75
76 if (memrefTy.getRank() == 0) {
77 if (isScalarLikeType(memrefTy.getElementType())) {
78 return mlir::acc::VariableTypeCategory::scalar;
79 }
80 // Zero-rank non-scalar - need further analysis to determine the type
81 // category. For now, return uncategorized.
82 return mlir::acc::VariableTypeCategory::uncategorized;
83 }
84
85 // It has a rank - must be an array.
86 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
87 return mlir::acc::VariableTypeCategory::array;
88 }
89
90 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
91 StringRef varName, Type varType, Value originalVar,
92 bool &needsFree) const {
93 auto memrefTy = cast<MemRefType>(pointer);
94
95 // Check if this is a static memref (all dimensions are known) - if yes
96 // then we can generate an alloca operation.
97 if (memrefTy.hasStaticShape()) {
98 needsFree = false; // alloca doesn't need deallocation
99 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
100 attachVarNameAttr(allocaOp, builder, varName);
101 return allocaOp.getResult();
102 }
103
104 // For dynamic memrefs, extract sizes from the original variable if
105 // provided. Otherwise they cannot be handled.
106 if (originalVar && originalVar.getType() == memrefTy &&
107 memrefTy.hasRank()) {
108 SmallVector<Value> dynamicSizes;
109 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
110 if (memrefTy.isDynamicDim(i)) {
111 // Extract the size of dimension i from the original variable
112 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
113 auto dimSize =
114 memref::DimOp::create(builder, loc, originalVar, indexValue);
115 dynamicSizes.push_back(dimSize);
116 }
117 // Note: We only add dynamic sizes to the dynamicSizes array
118 // Static dimensions are handled automatically by AllocOp
119 }
120 needsFree = true; // alloc needs deallocation
121 auto allocOp =
122 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
123 attachVarNameAttr(allocOp, builder, varName);
124 return allocOp.getResult();
125 }
126
127 // TODO: Unranked not yet supported.
128 return {};
129 }
130
131 bool genFree(Type pointer, OpBuilder &builder, Location loc,
132 TypedValue<PointerLikeType> varToFree, Value allocRes,
133 Type varType) const {
134 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
135 // Use allocRes if provided to determine the allocation type
136 Value valueToInspect = allocRes ? allocRes : memrefValue;
137
138 // Walk through casts to find the original allocation
139 Value currentValue = valueToInspect;
140 Operation *originalAlloc = nullptr;
141
142 // Follow the chain of operations to find the original allocation
143 // even if a casted result is provided.
144 while (currentValue) {
145 if (auto *definingOp = currentValue.getDefiningOp()) {
146 // Check if this is an allocation operation
147 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
148 originalAlloc = definingOp;
149 break;
150 }
151
152 // Check if this is a cast operation we can look through
153 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
154 currentValue = castOp.getSource();
155 continue;
156 }
157
158 // Check for other cast-like operations
159 if (auto reinterpretCastOp =
160 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
161 currentValue = reinterpretCastOp.getSource();
162 continue;
163 }
164
165 // If we can't look through this operation, stop
166 break;
167 }
168 // This is a block argument or similar - can't trace further.
169 break;
170 }
171
172 if (originalAlloc) {
173 if (isa<memref::AllocaOp>(originalAlloc)) {
174 // This is an alloca - no dealloc needed, but return true (success)
175 return true;
176 }
177 if (isa<memref::AllocOp>(originalAlloc)) {
178 // This is an alloc - generate dealloc on varToFree
179 memref::DeallocOp::create(builder, loc, memrefValue);
180 return true;
181 }
182 }
183 }
184
185 return false;
186 }
187
188 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
189 TypedValue<PointerLikeType> destination,
190 TypedValue<PointerLikeType> source, Type varType) const {
191 // Generate a copy operation between two memrefs
192 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
193 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
194
195 // As per memref documentation, source and destination must have same
196 // element type and shape in order to be compatible. We do not want to fail
197 // with an IR verification error - thus check that before generating the
198 // copy operation.
199 if (destMemref && srcMemref &&
200 destMemref.getType().getElementType() ==
201 srcMemref.getType().getElementType() &&
202 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
203 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
204 return true;
205 }
206
207 return false;
208 }
209
210 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
212 Type valueType) const {
213 // Load from a memref - only valid for scalar memrefs (rank 0).
214 // This is because the address computation for memrefs is part of the load
215 // (and not computed separately), but the API does not have arguments for
216 // indexing.
217 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
218 if (!memrefValue)
219 return {};
220
221 auto memrefTy = memrefValue.getType();
222
223 // Only load from scalar memrefs (rank 0)
224 if (memrefTy.getRank() != 0)
225 return {};
226
227 return memref::LoadOp::create(builder, loc, memrefValue);
228 }
229
230 bool genStore(Type pointer, OpBuilder &builder, Location loc,
231 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
232 // Store to a memref - only valid for scalar memrefs (rank 0)
233 // This is because the address computation for memrefs is part of the store
234 // (and not computed separately), but the API does not have arguments for
235 // indexing.
236 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
237 if (!memrefValue)
238 return false;
239
240 auto memrefTy = memrefValue.getType();
241
242 // Only store to scalar memrefs (rank 0)
243 if (memrefTy.getRank() != 0)
244 return false;
245
246 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
247 return true;
248 }
249
250 Value genCast(Type, OpBuilder &builder, Location loc, Value value,
251 Type resultType) const {
252 if (value.getType() == resultType)
253 return value;
254
255 if (isa<BaseMemRefType>(value.getType()) &&
256 isa<BaseMemRefType>(resultType)) {
257 if (memref::CastOp::areCastCompatible(TypeRange(value.getType()),
258 TypeRange(resultType)))
259 return memref::CastOp::create(builder, loc, resultType, value);
260 if (memref::MemorySpaceCastOp::areCastCompatible(
261 TypeRange(value.getType()), TypeRange(resultType)))
262 return memref::MemorySpaceCastOp::create(builder, loc, resultType,
263 value);
264 }
265
266 // If one side is not a memref, try the other type's `PointerLikeType`
267 // implementation (since it may be an out-of-tree reference type that
268 // we cannot generate here).
269 if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
270 if (!isa<BaseMemRefType>(resPtrLike))
271 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
272 return v;
273 if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
274 if (!isa<BaseMemRefType>(valPtrLike))
275 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
276 return v;
277
278 return {};
279 }
280
281 bool isDeviceData(Type pointer, Value var) const {
282 auto memrefTy = cast<T>(pointer);
283 Attribute memSpace = memrefTy.getMemorySpace();
284 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
285 }
286};
287
288struct LLVMPointerPointerLikeModel
289 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
290 LLVM::LLVMPointerType> {
291 Type getElementType(Type pointer) const { return Type(); }
292
293 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
295 Type valueType) const {
296 // For LLVM pointers, we need the valueType to determine what to load
297 if (!valueType)
298 return {};
299
300 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
301 }
302
303 bool genStore(Type pointer, OpBuilder &builder, Location loc,
304 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
305 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
306 return true;
307 }
308
309 Value genCast(Type, OpBuilder &builder, Location loc, Value value,
310 Type resultType) const {
311 if (value.getType() == resultType)
312 return value;
313
314 auto srcPtrTy = dyn_cast<LLVM::LLVMPointerType>(value.getType());
315 auto dstPtrTy = dyn_cast<LLVM::LLVMPointerType>(resultType);
316 if (srcPtrTy && dstPtrTy) {
317 if (srcPtrTy.getAddressSpace() != dstPtrTy.getAddressSpace())
318 return LLVM::AddrSpaceCastOp::create(builder, loc, resultType, value);
319 return value;
320 }
321
322 if (srcPtrTy && isa<IntegerType>(resultType))
323 return LLVM::PtrToIntOp::create(builder, loc, resultType, value);
324
325 if (dstPtrTy) {
326 Value intVal = value;
327 if (isa<IndexType>(value.getType()))
328 intVal = arith::IndexCastUIOp::create(builder, loc,
329 builder.getI64Type(), value);
330 if (isa<IntegerType>(intVal.getType()))
331 return LLVM::IntToPtrOp::create(builder, loc, resultType, intVal);
332 }
333
334 if (auto resPtrLike = dyn_cast<PointerLikeType>(resultType))
335 if (!isa<LLVM::LLVMPointerType>(resPtrLike))
336 if (Value v = resPtrLike.genCast(builder, loc, value, resultType))
337 return v;
338 if (auto valPtrLike = dyn_cast<PointerLikeType>(value.getType()))
339 if (!isa<LLVM::LLVMPointerType>(valPtrLike))
340 if (Value v = valPtrLike.genCast(builder, loc, value, resultType))
341 return v;
342
343 return UnrealizedConversionCastOp::create(builder, loc,
344 TypeRange(resultType), value)
345 .getResult(0);
346 }
347};
348
349struct MemrefAddressOfGlobalModel
350 : public AddressOfGlobalOpInterface::ExternalModel<
351 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
352 SymbolRefAttr getSymbol(Operation *op) const {
353 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
354 return getGlobalOp.getNameAttr();
355 }
356};
357
358struct MemrefGlobalVariableModel
359 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
360 memref::GlobalOp> {
361 bool isConstant(Operation *op) const {
362 auto globalOp = cast<memref::GlobalOp>(op);
363 return globalOp.getConstant();
364 }
365
366 Region *getInitRegion(Operation *op) const {
367 // GlobalOp uses attributes for initialization, not regions
368 return nullptr;
369 }
370
371 bool isDeviceData(Operation *op) const {
372 auto globalOp = cast<memref::GlobalOp>(op);
373 Attribute memSpace = globalOp.getType().getMemorySpace();
374 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
375 }
376};
377
378struct GPULaunchOffloadRegionModel
379 : public acc::OffloadRegionOpInterface::ExternalModel<
380 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
381 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
382 return cast<gpu::LaunchOp>(op).getBody();
383 }
384};
385
386/// Helper function for any of the times we need to modify an ArrayAttr based on
387/// a device type list. Returns a new ArrayAttr with all of the
388/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
389/// list is empty).
390mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
391 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
392 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
394 if (existingDeviceTypes)
395 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
396
397 if (newDeviceTypes.empty())
398 deviceTypes.push_back(
399 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
400
401 for (DeviceType dt : newDeviceTypes)
402 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
403
404 return mlir::ArrayAttr::get(context, deviceTypes);
405}
406
407/// Helper function for any of the times we need to add operands that are
408/// affected by a device type list. Returns a new ArrayAttr with all of the
409/// existingDeviceTypes, plus the effective new ones (or an added none, if the
410/// new list is empty). Additionally, adds the arguments to the argCollection
411/// the correct number of times. This will also update a 'segments' array, even
412/// if it won't be used.
413mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
414 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
415 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
416 mlir::MutableOperandRange argCollection,
417 llvm::SmallVector<int32_t> &segments) {
419 if (existingDeviceTypes)
420 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
421
422 if (newDeviceTypes.empty()) {
423 argCollection.append(arguments);
424 segments.push_back(arguments.size());
425 deviceTypes.push_back(
426 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
427 }
428
429 for (DeviceType dt : newDeviceTypes) {
430 argCollection.append(arguments);
431 segments.push_back(arguments.size());
432 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
433 }
434
435 return mlir::ArrayAttr::get(context, deviceTypes);
436}
437
438/// Overload for when the 'segments' aren't needed.
439mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
440 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
441 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
442 mlir::MutableOperandRange argCollection) {
444 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
445 newDeviceTypes, arguments,
446 argCollection, segments);
447}
448} // namespace
449
450//===----------------------------------------------------------------------===//
451// OpenACC operations
452//===----------------------------------------------------------------------===//
453
454void OpenACCDialect::initialize() {
455 addOperations<
456#define GET_OP_LIST
457#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
458 >();
459 addAttributes<
460#define GET_ATTRDEF_LIST
461#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
462 >();
463 addTypes<
464#define GET_TYPEDEF_LIST
465#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
466 >();
467
468 // By attaching interfaces here, we make the OpenACC dialect dependent on
469 // the other dialects. This is probably better than having dialects like LLVM
470 // and memref be dependent on OpenACC.
471 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
472 *getContext());
473 UnrankedMemRefType::attachInterface<
474 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
475 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
476 *getContext());
477
478 // Attach operation interfaces
479 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
480 *getContext());
481 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
482 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
483}
484
485//===----------------------------------------------------------------------===//
486// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
487// acc.kernel_environment / acc.data / acc.host_data / acc.loop
488//===----------------------------------------------------------------------===//
489
490/// Generic helper for single-region OpenACC ops that execute their body once
491/// and then return to the parent operation with their results (if any).
492static void
494 RegionBranchPoint point,
496 if (point.isParent()) {
497 regions.push_back(RegionSuccessor(&region));
498 return;
499 }
500
501 regions.push_back(RegionSuccessor::parent());
502}
503
505 RegionSuccessor successor) {
506 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
507}
508
509void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
511 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
512 regions);
513}
514
515ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
516 return getSingleRegionSuccessorInputs(getOperation(), successor);
517}
518
519void ParallelOp::getSuccessorRegions(
521 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
522 regions);
523}
524
525ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
526 return getSingleRegionSuccessorInputs(getOperation(), successor);
527}
528
529void SerialOp::getSuccessorRegions(RegionBranchPoint point,
531 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
532 regions);
533}
534
535ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
536 return getSingleRegionSuccessorInputs(getOperation(), successor);
537}
538
539void DataOp::getSuccessorRegions(RegionBranchPoint point,
541 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
542 regions);
543}
544
545ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
546 return getSingleRegionSuccessorInputs(getOperation(), successor);
547}
548
549void HostDataOp::getSuccessorRegions(
551 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
552 regions);
553}
554
555ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
556 return getSingleRegionSuccessorInputs(getOperation(), successor);
557}
558
559void LoopOp::getSuccessorRegions(RegionBranchPoint point,
561 // Unstructured loops: the body may contain arbitrary CFG and early exits.
562 // At the RegionBranch level, only model entry into the body and exit to the
563 // parent; any backedges are represented inside the region CFG.
564 if (getUnstructured()) {
565 if (point.isParent()) {
566 regions.push_back(RegionSuccessor(&getRegion()));
567 return;
568 }
569 regions.push_back(RegionSuccessor::parent());
570 return;
571 }
572
573 // Structured loops: model a loop-shaped region graph similar to scf.for.
574 regions.push_back(RegionSuccessor(&getRegion()));
575 regions.push_back(RegionSuccessor::parent());
576}
577
578ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
579 return getSingleRegionSuccessorInputs(getOperation(), successor);
580}
581
582//===----------------------------------------------------------------------===//
583// RegionBranchTerminatorOpInterface
584//===----------------------------------------------------------------------===//
585
587TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
588 // `acc.terminator` does not forward operands.
589 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
590}
591
592//===----------------------------------------------------------------------===//
593// device_type support helpers
594//===----------------------------------------------------------------------===//
595
596static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
597 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
598}
599
600static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
601 mlir::acc::DeviceType deviceType) {
602 if (!hasDeviceTypeValues(arrayAttr))
603 return false;
604
605 for (auto attr : *arrayAttr) {
606 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
607 if (deviceTypeAttr.getValue() == deviceType)
608 return true;
609 }
610
611 return false;
612}
613
615 std::optional<mlir::ArrayAttr> deviceTypes) {
616 if (!hasDeviceTypeValues(deviceTypes))
617 return;
618
619 p << "[";
620 llvm::interleaveComma(*deviceTypes, p,
621 [&](mlir::Attribute attr) { p << attr; });
622 p << "]";
623}
624
625static std::optional<unsigned> findSegment(ArrayAttr segments,
626 mlir::acc::DeviceType deviceType) {
627 unsigned segmentIdx = 0;
628 for (auto attr : segments) {
629 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
630 if (deviceTypeAttr.getValue() == deviceType)
631 return std::make_optional(segmentIdx);
632 ++segmentIdx;
633 }
634 return std::nullopt;
635}
636
638getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
640 std::optional<llvm::ArrayRef<int32_t>> segments,
641 mlir::acc::DeviceType deviceType) {
642 if (!arrayAttr)
643 return range.take_front(0);
644 if (auto pos = findSegment(*arrayAttr, deviceType)) {
645 int32_t nbOperandsBefore = 0;
646 for (unsigned i = 0; i < *pos; ++i)
647 nbOperandsBefore += (*segments)[i];
648 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
649 }
650 return range.take_front(0);
651}
652
653static mlir::Value
654getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
656 std::optional<llvm::ArrayRef<int32_t>> segments,
657 std::optional<mlir::ArrayAttr> hasWaitDevnum,
658 mlir::acc::DeviceType deviceType) {
659 if (!hasDeviceTypeValues(deviceTypeAttr))
660 return {};
661 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
662 if (hasWaitDevnum->getValue()[*pos])
663 return getValuesFromSegments(deviceTypeAttr, operands, segments,
664 deviceType)
665 .front();
666 return {};
667}
668
670getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
672 std::optional<llvm::ArrayRef<int32_t>> segments,
673 std::optional<mlir::ArrayAttr> hasWaitDevnum,
674 mlir::acc::DeviceType deviceType) {
675 auto range =
676 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
677 if (range.empty())
678 return range;
679 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
680 if (hasWaitDevnum && *hasWaitDevnum) {
681 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
682 if (boolAttr.getValue())
683 return range.drop_front(1); // first value is devnum
684 }
685 }
686 return range;
687}
688
689template <typename Op>
690static LogicalResult checkWaitAndAsyncConflict(Op op) {
691 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
692 ++dtypeInt) {
693 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
694
695 // The asyncOnly attribute represent the async clause without value.
696 // Therefore the attribute and operand cannot appear at the same time.
697 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
698 op.hasAsyncOnly(dtype))
699 return op.emitError(
700 "asyncOnly attribute cannot appear with asyncOperand");
701
702 // The wait attribute represent the wait clause without values. Therefore
703 // the attribute and operands cannot appear at the same time.
704 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
705 op.hasWaitOnly(dtype))
706 return op.emitError("wait attribute cannot appear with waitOperands");
707 }
708 return success();
709}
710
711template <typename Op>
712static LogicalResult checkVarAndVarType(Op op) {
713 if (!op.getVar())
714 return op.emitError("must have var operand");
715
716 // A variable must have a type that is either pointer-like or mappable.
717 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
718 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
719 return op.emitError("var must be mappable or pointer-like");
720
721 // When it is a pointer-like type, the varType must capture the target type.
722 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
723 op.getVarType() == op.getVar().getType())
724 return op.emitError("varType must capture the element type of var");
725
726 return success();
727}
728
729template <typename Op>
730static LogicalResult checkVarAndAccVar(Op op) {
731 if (op.getVar().getType() != op.getAccVar().getType())
732 return op.emitError("input and output types must match");
733
734 return success();
735}
736
737template <typename Op>
738static LogicalResult checkNoModifier(Op op) {
739 if (op.getModifiers() != acc::DataClauseModifier::none)
740 return op.emitError("no data clause modifiers are allowed");
741 return success();
742}
743
744template <typename Op>
745static LogicalResult
746checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
747 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
748 return op.emitError(
749 "invalid data clause modifiers: " +
750 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
751
752 return success();
753}
754
755template <typename OpT, typename RecipeOpT>
756static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
757 // Mappable types do not need a recipe because it is possible to generate one
758 // from its API. Reject reductions though because no API is available for them
759 // at this time.
760 if (mlir::acc::isMappableType(op.getVar().getType()) &&
761 !std::is_same_v<OpT, acc::ReductionOp>)
762 return success();
763
764 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
765 if (!operandRecipe)
766 return op->emitOpError() << "recipe expected for " << operandName;
767
768 auto decl =
770 if (!decl)
771 return op->emitOpError()
772 << "expected symbol reference " << operandRecipe << " to point to a "
773 << operandName << " declaration";
774 return success();
775}
776
777static ParseResult parseVar(mlir::OpAsmParser &parser,
779 // Either `var` or `varPtr` keyword is required.
780 if (failed(parser.parseOptionalKeyword("varPtr"))) {
781 if (failed(parser.parseKeyword("var")))
782 return failure();
783 }
784 if (failed(parser.parseLParen()))
785 return failure();
786 if (failed(parser.parseOperand(var)))
787 return failure();
788
789 return success();
790}
791
793 mlir::Value var) {
794 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
795 p << "varPtr(";
796 else
797 p << "var(";
798 p.printOperand(var);
799}
800
801static ParseResult parseAccVar(mlir::OpAsmParser &parser,
803 mlir::Type &accVarType) {
804 // Either `accVar` or `accPtr` keyword is required.
805 if (failed(parser.parseOptionalKeyword("accPtr"))) {
806 if (failed(parser.parseKeyword("accVar")))
807 return failure();
808 }
809 if (failed(parser.parseLParen()))
810 return failure();
811 if (failed(parser.parseOperand(var)))
812 return failure();
813 if (failed(parser.parseColon()))
814 return failure();
815 if (failed(parser.parseType(accVarType)))
816 return failure();
817 if (failed(parser.parseRParen()))
818 return failure();
819
820 return success();
821}
822
824 mlir::Value accVar, mlir::Type accVarType) {
825 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
826 p << "accPtr(";
827 else
828 p << "accVar(";
829 p.printOperand(accVar);
830 p << " : ";
831 p.printType(accVarType);
832 p << ")";
833}
834
835static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
836 mlir::Type &varPtrType,
837 mlir::TypeAttr &varTypeAttr) {
838 if (failed(parser.parseType(varPtrType)))
839 return failure();
840 if (failed(parser.parseRParen()))
841 return failure();
842
843 if (succeeded(parser.parseOptionalKeyword("varType"))) {
844 if (failed(parser.parseLParen()))
845 return failure();
846 mlir::Type varType;
847 if (failed(parser.parseType(varType)))
848 return failure();
849 varTypeAttr = mlir::TypeAttr::get(varType);
850 if (failed(parser.parseRParen()))
851 return failure();
852 } else {
853 // Set `varType` from the element type of the type of `varPtr`.
854 if (auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
855 Type elementType = ptrTy.getElementType();
856 // Opaque pointers (e.g. !llvm.ptr) have no element type; fall back to
857 // using varPtrType itself so that the attribute is always valid.
858 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
859 } else {
860 varTypeAttr = mlir::TypeAttr::get(varPtrType);
861 }
862 }
863
864 return success();
865}
866
868 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
869 p.printType(varPtrType);
870 p << ")";
871
872 // Print the `varType` only if it differs from the element type of
873 // `varPtr`'s type.
874 mlir::Type varType = varTypeAttr.getValue();
875 mlir::Type typeToCheckAgainst =
876 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
877 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
878 : varPtrType;
879 // Opaque pointers (e.g. !llvm.ptr) have no element type; use varPtrType as
880 // the baseline so that the inferred varType is not redundantly printed.
881 if (!typeToCheckAgainst)
882 typeToCheckAgainst = varPtrType;
883 if (typeToCheckAgainst != varType) {
884 p << " varType(";
885 p.printType(varType);
886 p << ")";
887 }
888}
889
890static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
891 mlir::SymbolRefAttr &recipeAttr) {
892 if (failed(parser.parseAttribute(recipeAttr)))
893 return failure();
894 return success();
895}
896
898 mlir::SymbolRefAttr recipeAttr) {
899 p << recipeAttr;
900}
901
902//===----------------------------------------------------------------------===//
903// DataBoundsOp
904//===----------------------------------------------------------------------===//
905LogicalResult acc::DataBoundsOp::verify() {
906 auto extent = getExtent();
907 auto upperbound = getUpperbound();
908 if (!extent && !upperbound)
909 return emitError("expected extent or upperbound.");
910 return success();
911}
912
913//===----------------------------------------------------------------------===//
914// PrivateOp
915//===----------------------------------------------------------------------===//
// Verifier for the private operation: the data clause must be acc_private,
// and the checkVarAndVarType and checkNoModifier helpers must succeed.
// NOTE(review): listing line 925 (the expression inside the last `failed(`
// call) is missing from this extract, so that final check is not visible.
916 LogicalResult acc::PrivateOp::verify() {
917 if (getDataClause() != acc::DataClause::acc_private)
918 return emitError(
919 "data clause associated with private operation must match its intent");
920 if (failed(checkVarAndVarType(*this)))
921 return failure();
922 if (failed(checkNoModifier(*this)))
923 return failure();
924 if (failed(
926 return failure();
927 return success();
928 }
929
930//===----------------------------------------------------------------------===//
931// FirstprivateOp
932//===----------------------------------------------------------------------===//
// Verifier for the firstprivate operation: the data clause must be
// acc_firstprivate, and the checkVarAndVarType and checkNoModifier helpers
// must succeed.
// NOTE(review): listing line 941 (the opening of the final check whose
// arguments end in `*this, "firstprivate"`) is missing from this extract.
933 LogicalResult acc::FirstprivateOp::verify() {
934 if (getDataClause() != acc::DataClause::acc_firstprivate)
935 return emitError("data clause associated with firstprivate operation must "
936 "match its intent");
937 if (failed(checkVarAndVarType(*this)))
938 return failure();
939 if (failed(checkNoModifier(*this)))
940 return failure();
942 *this, "firstprivate")))
943 return failure();
944 return success();
945 }
946
947//===----------------------------------------------------------------------===//
948// ReductionOp
949//===----------------------------------------------------------------------===//
// Verifier for the reduction operation: the data clause must be
// acc_reduction, and the checkVarAndVarType and checkNoModifier helpers must
// succeed.
// NOTE(review): listing line 958 (the opening of the final check whose
// arguments end in `*this, "reduction"`) is missing from this extract.
950 LogicalResult acc::ReductionOp::verify() {
951 if (getDataClause() != acc::DataClause::acc_reduction)
952 return emitError("data clause associated with reduction operation must "
953 "match its intent");
954 if (failed(checkVarAndVarType(*this)))
955 return failure();
956 if (failed(checkNoModifier(*this)))
957 return failure();
959 *this, "reduction")))
960 return failure();
961 return success();
962 }
963
964//===----------------------------------------------------------------------===//
965// DevicePtrOp
966//===----------------------------------------------------------------------===//
967LogicalResult acc::DevicePtrOp::verify() {
968 if (getDataClause() != acc::DataClause::acc_deviceptr)
969 return emitError("data clause associated with deviceptr operation must "
970 "match its intent");
971 if (failed(checkVarAndVarType(*this)))
972 return failure();
973 if (failed(checkVarAndAccVar(*this)))
974 return failure();
975 if (failed(checkNoModifier(*this)))
976 return failure();
977 return success();
978}
979
980//===----------------------------------------------------------------------===//
981// PresentOp
982//===----------------------------------------------------------------------===//
983LogicalResult acc::PresentOp::verify() {
984 if (getDataClause() != acc::DataClause::acc_present)
985 return emitError(
986 "data clause associated with present operation must match its intent");
987 if (failed(checkVarAndVarType(*this)))
988 return failure();
989 if (failed(checkVarAndAccVar(*this)))
990 return failure();
991 if (failed(checkNoModifier(*this)))
992 return failure();
993 return success();
994}
995
996//===----------------------------------------------------------------------===//
997// CopyinOp
998//===----------------------------------------------------------------------===//
999LogicalResult acc::CopyinOp::verify() {
1000 // Test for all clauses this operation can be decomposed from:
1001 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
1002 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1003 getDataClause() != acc::DataClause::acc_copy &&
1004 getDataClause() != acc::DataClause::acc_reduction)
1005 return emitError(
1006 "data clause associated with copyin operation must match its intent"
1007 " or specify original clause this operation was decomposed from");
1008 if (failed(checkVarAndVarType(*this)))
1009 return failure();
1010 if (failed(checkVarAndAccVar(*this)))
1011 return failure();
1012 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
1013 acc::DataClauseModifier::always |
1014 acc::DataClauseModifier::capture)))
1015 return failure();
1016 return success();
1017}
1018
1019bool acc::CopyinOp::isCopyinReadonly() {
1020 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
1021 acc::bitEnumContainsAny(getModifiers(),
1022 acc::DataClauseModifier::readonly);
1023}
1024
1025//===----------------------------------------------------------------------===//
1026// CreateOp
1027//===----------------------------------------------------------------------===//
1028LogicalResult acc::CreateOp::verify() {
1029 // Test for all clauses this operation can be decomposed from:
1030 if (getDataClause() != acc::DataClause::acc_create &&
1031 getDataClause() != acc::DataClause::acc_create_zero &&
1032 getDataClause() != acc::DataClause::acc_copyout &&
1033 getDataClause() != acc::DataClause::acc_copyout_zero)
1034 return emitError(
1035 "data clause associated with create operation must match its intent"
1036 " or specify original clause this operation was decomposed from");
1037 if (failed(checkVarAndVarType(*this)))
1038 return failure();
1039 if (failed(checkVarAndAccVar(*this)))
1040 return failure();
1041 // this op is the entry part of copyout, so it also needs to allow all
1042 // modifiers allowed on copyout.
1043 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1044 acc::DataClauseModifier::always |
1045 acc::DataClauseModifier::capture)))
1046 return failure();
1047 return success();
1048}
1049
1050bool acc::CreateOp::isCreateZero() {
1051 // The zero modifier is encoded in the data clause.
1052 return getDataClause() == acc::DataClause::acc_create_zero ||
1053 getDataClause() == acc::DataClause::acc_copyout_zero ||
1054 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1055}
1056
1057//===----------------------------------------------------------------------===//
1058// NoCreateOp
1059//===----------------------------------------------------------------------===//
1060LogicalResult acc::NoCreateOp::verify() {
1061 if (getDataClause() != acc::DataClause::acc_no_create)
1062 return emitError("data clause associated with no_create operation must "
1063 "match its intent");
1064 if (failed(checkVarAndVarType(*this)))
1065 return failure();
1066 if (failed(checkVarAndAccVar(*this)))
1067 return failure();
1068 if (failed(checkNoModifier(*this)))
1069 return failure();
1070 return success();
1071}
1072
1073//===----------------------------------------------------------------------===//
1074// AttachOp
1075//===----------------------------------------------------------------------===//
1076LogicalResult acc::AttachOp::verify() {
1077 if (getDataClause() != acc::DataClause::acc_attach)
1078 return emitError(
1079 "data clause associated with attach operation must match its intent");
1080 if (failed(checkVarAndVarType(*this)))
1081 return failure();
1082 if (failed(checkVarAndAccVar(*this)))
1083 return failure();
1084 if (failed(checkNoModifier(*this)))
1085 return failure();
1086 return success();
1087}
1088
1089//===----------------------------------------------------------------------===//
1090// DeclareDeviceResidentOp
1091//===----------------------------------------------------------------------===//
1092
1093LogicalResult acc::DeclareDeviceResidentOp::verify() {
1094 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1095 return emitError("data clause associated with device_resident operation "
1096 "must match its intent");
1097 if (failed(checkVarAndVarType(*this)))
1098 return failure();
1099 if (failed(checkVarAndAccVar(*this)))
1100 return failure();
1101 if (failed(checkNoModifier(*this)))
1102 return failure();
1103 return success();
1104}
1105
1106//===----------------------------------------------------------------------===//
1107// DeclareLinkOp
1108//===----------------------------------------------------------------------===//
1109
1110LogicalResult acc::DeclareLinkOp::verify() {
1111 if (getDataClause() != acc::DataClause::acc_declare_link)
1112 return emitError(
1113 "data clause associated with link operation must match its intent");
1114 if (failed(checkVarAndVarType(*this)))
1115 return failure();
1116 if (failed(checkVarAndAccVar(*this)))
1117 return failure();
1118 if (failed(checkNoModifier(*this)))
1119 return failure();
1120 return success();
1121}
1122
1123//===----------------------------------------------------------------------===//
1124// CopyoutOp
1125//===----------------------------------------------------------------------===//
1126LogicalResult acc::CopyoutOp::verify() {
1127 // Test for all clauses this operation can be decomposed from:
1128 if (getDataClause() != acc::DataClause::acc_copyout &&
1129 getDataClause() != acc::DataClause::acc_copyout_zero &&
1130 getDataClause() != acc::DataClause::acc_copy &&
1131 getDataClause() != acc::DataClause::acc_reduction)
1132 return emitError(
1133 "data clause associated with copyout operation must match its intent"
1134 " or specify original clause this operation was decomposed from");
1135 if (!getVar() || !getAccVar())
1136 return emitError("must have both host and device pointers");
1137 if (failed(checkVarAndVarType(*this)))
1138 return failure();
1139 if (failed(checkVarAndAccVar(*this)))
1140 return failure();
1141 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1142 acc::DataClauseModifier::always |
1143 acc::DataClauseModifier::capture)))
1144 return failure();
1145 return success();
1146}
1147
1148bool acc::CopyoutOp::isCopyoutZero() {
1149 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1150 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1151}
1152
1153//===----------------------------------------------------------------------===//
1154// DeleteOp
1155//===----------------------------------------------------------------------===//
1156LogicalResult acc::DeleteOp::verify() {
1157 // Test for all clauses this operation can be decomposed from:
1158 if (getDataClause() != acc::DataClause::acc_delete &&
1159 getDataClause() != acc::DataClause::acc_create &&
1160 getDataClause() != acc::DataClause::acc_create_zero &&
1161 getDataClause() != acc::DataClause::acc_copyin &&
1162 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1163 getDataClause() != acc::DataClause::acc_present &&
1164 getDataClause() != acc::DataClause::acc_no_create &&
1165 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1166 getDataClause() != acc::DataClause::acc_declare_link)
1167 return emitError(
1168 "data clause associated with delete operation must match its intent"
1169 " or specify original clause this operation was decomposed from");
1170 if (!getAccVar())
1171 return emitError("must have device pointer");
1172 // This op is the exit part of copyin and create - thus allow all modifiers
1173 // allowed on either case.
1174 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1175 acc::DataClauseModifier::readonly |
1176 acc::DataClauseModifier::always |
1177 acc::DataClauseModifier::capture)))
1178 return failure();
1179 return success();
1180}
1181
1182//===----------------------------------------------------------------------===//
1183// DetachOp
1184//===----------------------------------------------------------------------===//
1185LogicalResult acc::DetachOp::verify() {
1186 // Test for all clauses this operation can be decomposed from:
1187 if (getDataClause() != acc::DataClause::acc_detach &&
1188 getDataClause() != acc::DataClause::acc_attach)
1189 return emitError(
1190 "data clause associated with detach operation must match its intent"
1191 " or specify original clause this operation was decomposed from");
1192 if (!getAccVar())
1193 return emitError("must have device pointer");
1194 if (failed(checkNoModifier(*this)))
1195 return failure();
1196 return success();
1197}
1198
1199//===----------------------------------------------------------------------===//
1200 // UpdateHostOp
1201//===----------------------------------------------------------------------===//
1202LogicalResult acc::UpdateHostOp::verify() {
1203 // Test for all clauses this operation can be decomposed from:
1204 if (getDataClause() != acc::DataClause::acc_update_host &&
1205 getDataClause() != acc::DataClause::acc_update_self)
1206 return emitError(
1207 "data clause associated with host operation must match its intent"
1208 " or specify original clause this operation was decomposed from");
1209 if (!getVar() || !getAccVar())
1210 return emitError("must have both host and device pointers");
1211 if (failed(checkVarAndVarType(*this)))
1212 return failure();
1213 if (failed(checkVarAndAccVar(*this)))
1214 return failure();
1215 if (failed(checkNoModifier(*this)))
1216 return failure();
1217 return success();
1218}
1219
1220//===----------------------------------------------------------------------===//
1221 // UpdateDeviceOp
1222//===----------------------------------------------------------------------===//
1223LogicalResult acc::UpdateDeviceOp::verify() {
1224 // Test for all clauses this operation can be decomposed from:
1225 if (getDataClause() != acc::DataClause::acc_update_device)
1226 return emitError(
1227 "data clause associated with device operation must match its intent"
1228 " or specify original clause this operation was decomposed from");
1229 if (failed(checkVarAndVarType(*this)))
1230 return failure();
1231 if (failed(checkVarAndAccVar(*this)))
1232 return failure();
1233 if (failed(checkNoModifier(*this)))
1234 return failure();
1235 return success();
1236}
1237
1238//===----------------------------------------------------------------------===//
1239// UseDeviceOp
1240//===----------------------------------------------------------------------===//
1241LogicalResult acc::UseDeviceOp::verify() {
1242 // Test for all clauses this operation can be decomposed from:
1243 if (getDataClause() != acc::DataClause::acc_use_device)
1244 return emitError(
1245 "data clause associated with use_device operation must match its intent"
1246 " or specify original clause this operation was decomposed from");
1247 if (failed(checkVarAndVarType(*this)))
1248 return failure();
1249 if (failed(checkVarAndAccVar(*this)))
1250 return failure();
1251 if (failed(checkNoModifier(*this)))
1252 return failure();
1253 return success();
1254}
1255
1256//===----------------------------------------------------------------------===//
1257// CacheOp
1258//===----------------------------------------------------------------------===//
1259LogicalResult acc::CacheOp::verify() {
1260 // Test for all clauses this operation can be decomposed from:
1261 if (getDataClause() != acc::DataClause::acc_cache &&
1262 getDataClause() != acc::DataClause::acc_cache_readonly)
1263 return emitError(
1264 "data clause associated with cache operation must match its intent"
1265 " or specify original clause this operation was decomposed from");
1266 if (failed(checkVarAndVarType(*this)))
1267 return failure();
1268 if (failed(checkVarAndAccVar(*this)))
1269 return failure();
1270 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1271 return failure();
1272 return success();
1273}
1274
1275bool acc::CacheOp::isCacheReadonly() {
1276 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1277 acc::bitEnumContainsAny(getModifiers(),
1278 acc::DataClauseModifier::readonly);
1279}
1280
1281//===----------------------------------------------------------------------===//
1282// Data entry/exit operations - getEffects implementations
1283//===----------------------------------------------------------------------===//
1284
1285 // This function returns true iff the given operation is enclosed
1286 // in any ACC_COMPUTE_CONSTRUCT_OPS operation.
1287 // It is similar to the acc::getEnclosingComputeOp() utility,
1288 // but we cannot use that utility here.
1292
1293 /// Helper to add an effect on an operand, referenced by its mutable range.
/// One `EffectTy` instance is appended to `effects` for every operand in
/// `operand`, each bound to the address of that operand.
// NOTE(review): listing lines 1295-1296 (the function name and the start of
// its parameter list) are missing from this extract.
1294 template <typename EffectTy>
1297 &effects,
1298 MutableOperandRange operand) {
1299 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1300 effects.emplace_back(EffectTy::get(), &operand[i]);
1301 }
1302
1303 /// Helper to add an effect on a result value.
/// `result` is cast to an `mlir::OpResult` and a single `EffectTy` instance
/// bound to it is appended to `effects`.
// NOTE(review): listing lines 1305-1306 (the function name and the start of
// its parameter list) are missing from this extract.
1304 template <typename EffectTy>
1307 &effects,
1308 Value result) {
1309 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1310 }
1311
1312 // PrivateOp: accVar result write.
// Appends this operation's memory effects to `effects`. A Read effect is
// recorded only when the operation is not enclosed in a compute operation.
// NOTE(review): listing lines 1314, 1321 and 1323 are missing from this
// extract, so the parameter type, the resource of the Read effect and the
// statement that follows the TODO are not visible here.
1313 void acc::PrivateOp::getEffects(
1315 &effects) {
1316 // If acc.private is enclosed into a compute operation,
1317 // then it denotes the device side privatization, hence
1318 // it does not access the CurrentDeviceIdResource.
1319 if (!isEnclosedIntoComputeOp(getOperation()))
1320 effects.emplace_back(MemoryEffects::Read::get(),
1322 // TODO: should this be MemoryEffects::Allocate?
1324 }
1325
1326 // FirstprivateOp: var read, accVar result write.
// Appends this operation's memory effects to `effects`: a conditional Read
// (only when not enclosed in a compute operation) and a Read on the `var`
// operand.
// NOTE(review): listing lines 1328, 1335 and 1337 are missing from this
// extract, so the parameter type, the resource of the conditional Read and
// the last statement are not visible here.
1327 void acc::FirstprivateOp::getEffects(
1329 &effects) {
1330 // If acc.firstprivate is enclosed into a compute operation,
1331 // then it denotes the device side privatization, hence
1332 // it does not access the CurrentDeviceIdResource.
1333 if (!isEnclosedIntoComputeOp(getOperation()))
1334 effects.emplace_back(MemoryEffects::Read::get(),
1336 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1338 }
1339
1340 // ReductionOp: var read, accVar result write.
// Appends this operation's memory effects to `effects`: a conditional Read
// (only when not enclosed in a compute operation) and a Read on the `var`
// operand.
// NOTE(review): listing lines 1342, 1349 and 1351 are missing from this
// extract, so the parameter type, the resource of the conditional Read and
// the last statement are not visible here.
1341 void acc::ReductionOp::getEffects(
1343 &effects) {
1344 // If acc.reduction is enclosed into a compute operation,
1345 // then it denotes the device side reduction, hence
1346 // it does not access the CurrentDeviceIdResource.
1347 if (!isEnclosedIntoComputeOp(getOperation()))
1348 effects.emplace_back(MemoryEffects::Read::get(),
1350 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1352 }
1353
1354 // DevicePtrOp: RuntimeCounters read.
// Appends a Read of acc::RuntimeCounters to `effects`, followed by a second
// Read whose resource argument is not visible here.
// NOTE(review): listing lines 1356 and 1360 are missing from this extract
// (the parameter type and the second Read's resource).
1355 void acc::DevicePtrOp::getEffects(
1357 &effects) {
1358 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1359 effects.emplace_back(MemoryEffects::Read::get(),
1361 }
1362
1363 // PresentOp: RuntimeCounters read+write.
// Appends a Read of acc::RuntimeCounters to `effects`, followed by a Write
// and a further Read whose resource arguments are not visible here.
// NOTE(review): listing lines 1365, 1369 and 1371 are missing from this
// extract (the parameter type and the resources of the last two effects).
1364 void acc::PresentOp::getEffects(
1366 &effects) {
1367 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1368 effects.emplace_back(MemoryEffects::Write::get(),
1370 effects.emplace_back(MemoryEffects::Read::get(),
1372 }
1373
1374 // CopyinOp: RuntimeCounters read+write, var read, accVar result write.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, and a Read on the `var`
// operand.
// NOTE(review): listing lines 1376, 1380, 1382 and 1384 are missing from
// this extract (the parameter type, two resource arguments and the last
// statement).
1375 void acc::CopyinOp::getEffects(
1377 &effects) {
1378 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1379 effects.emplace_back(MemoryEffects::Write::get(),
1381 effects.emplace_back(MemoryEffects::Read::get(),
1383 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1385 }
1386
1387 // CreateOp: RuntimeCounters read+write, accVar result write.
// Appends to `effects`: a Read of acc::RuntimeCounters, then a Write and a
// Read whose resource arguments are not visible here.
// NOTE(review): listing lines 1389, 1393, 1395 and 1397 are missing from
// this extract (the parameter type, two resource arguments and the statement
// that follows the TODO).
1388 void acc::CreateOp::getEffects(
1390 &effects) {
1391 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1392 effects.emplace_back(MemoryEffects::Write::get(),
1394 effects.emplace_back(MemoryEffects::Read::get(),
1396 // TODO: should this be MemoryEffects::Allocate?
1398 }
1399
1400 // NoCreateOp: RuntimeCounters read+write.
// Appends a Read of acc::RuntimeCounters to `effects`, followed by a Write
// and a further Read whose resource arguments are not visible here.
// NOTE(review): listing lines 1402, 1406 and 1408 are missing from this
// extract (the parameter type and the resources of the last two effects).
1401 void acc::NoCreateOp::getEffects(
1403 &effects) {
1404 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1405 effects.emplace_back(MemoryEffects::Write::get(),
1407 effects.emplace_back(MemoryEffects::Read::get(),
1409 }
1410
1411 // AttachOp: RuntimeCounters read+write, var read.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, and a Read on the `var`
// operand.
// NOTE(review): listing lines 1413, 1417 and 1419 are missing from this
// extract (the parameter type and two resource arguments).
1412 void acc::AttachOp::getEffects(
1414 &effects) {
1415 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1416 effects.emplace_back(MemoryEffects::Write::get(),
1418 effects.emplace_back(MemoryEffects::Read::get(),
1420 // TODO: should we also add MemoryEffects::Write?
1421 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1422 }
1423
1424 // GetDevicePtrOp: RuntimeCounters read.
// Appends a Read of acc::RuntimeCounters to `effects`, followed by a second
// Read whose resource argument is not visible here.
// NOTE(review): listing lines 1426 and 1430 are missing from this extract
// (the parameter type and the second Read's resource).
1425 void acc::GetDevicePtrOp::getEffects(
1427 &effects) {
1428 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1429 effects.emplace_back(MemoryEffects::Read::get(),
1431 }
1432
1433 // UpdateDeviceOp: var read, accVar result write.
// Appends to `effects`: a Read whose resource argument is not visible here
// and a Read on the `var` operand.
// NOTE(review): listing lines 1435, 1438 and 1440 are missing from this
// extract (the parameter type, the first Read's resource and the last
// statement).
1434 void acc::UpdateDeviceOp::getEffects(
1436 &effects) {
1437 effects.emplace_back(MemoryEffects::Read::get(),
1439 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1441 }
1442
1443 // UseDeviceOp: RuntimeCounters read.
// Appends a Read of acc::RuntimeCounters to `effects`, followed by a second
// Read whose resource argument is not visible here.
// NOTE(review): listing lines 1445 and 1449 are missing from this extract
// (the parameter type and the second Read's resource).
1444 void acc::UseDeviceOp::getEffects(
1446 &effects) {
1447 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1448 effects.emplace_back(MemoryEffects::Read::get(),
1450 }
1451
1452 // DeclareDeviceResidentOp: RuntimeCounters write, var read.
// Appends to `effects`: a Write and a Read whose resource arguments are not
// visible here, and a Read on the `var` operand.
// NOTE(review): listing lines 1454, 1457 and 1459 are missing from this
// extract (the parameter type and two resource arguments).
1453 void acc::DeclareDeviceResidentOp::getEffects(
1455 &effects) {
1456 effects.emplace_back(MemoryEffects::Write::get(),
1458 effects.emplace_back(MemoryEffects::Read::get(),
1460 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1461 }
1462
1463 // DeclareLinkOp: RuntimeCounters write, var read.
// Appends to `effects`: a Write and a Read whose resource arguments are not
// visible here, and a Read on the `var` operand.
// NOTE(review): listing lines 1465, 1468 and 1470 are missing from this
// extract (the parameter type and two resource arguments).
1464 void acc::DeclareLinkOp::getEffects(
1466 &effects) {
1467 effects.emplace_back(MemoryEffects::Write::get(),
1469 effects.emplace_back(MemoryEffects::Read::get(),
1471 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1472 }
1473
1474 // CacheOp: NoMemoryEffect
// The body is intentionally empty: nothing is appended to `effects`.
// NOTE(review): listing line 1476 (the parameter type) is missing from this
// extract.
1475 void acc::CacheOp::getEffects(
1477 &effects) {}
1478
1479 // CopyoutOp: RuntimeCounters read+write, accVar read, var write.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, a Read on the `accVar`
// operand and a Write on the `var` operand.
// NOTE(review): listing lines 1481, 1485 and 1487 are missing from this
// extract (the parameter type and two resource arguments).
1480 void acc::CopyoutOp::getEffects(
1482 &effects) {
1483 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1484 effects.emplace_back(MemoryEffects::Write::get(),
1486 effects.emplace_back(MemoryEffects::Read::get(),
1488 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1489 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1490 }
1491
1492 // DeleteOp: RuntimeCounters read+write, accVar read.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, and a Read on the `accVar`
// operand.
// NOTE(review): listing lines 1494, 1498 and 1500 are missing from this
// extract (the parameter type and two resource arguments).
1493 void acc::DeleteOp::getEffects(
1495 &effects) {
1496 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1497 effects.emplace_back(MemoryEffects::Write::get(),
1499 effects.emplace_back(MemoryEffects::Read::get(),
1501 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1502 }
1503
1504 // DetachOp: RuntimeCounters read+write, accVar read.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, and a Read on the `accVar`
// operand.
// NOTE(review): listing lines 1506, 1510 and 1512 are missing from this
// extract (the parameter type and two resource arguments).
1505 void acc::DetachOp::getEffects(
1507 &effects) {
1508 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1509 effects.emplace_back(MemoryEffects::Write::get(),
1511 effects.emplace_back(MemoryEffects::Read::get(),
1513 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1514 }
1515
1516 // UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
// Appends to `effects`: a Read of acc::RuntimeCounters, a Write and a Read
// whose resource arguments are not visible here, a Read on the `accVar`
// operand and a Write on the `var` operand.
// NOTE(review): listing lines 1518, 1522 and 1524 are missing from this
// extract (the parameter type and two resource arguments).
1517 void acc::UpdateHostOp::getEffects(
1519 &effects) {
1520 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1521 effects.emplace_back(MemoryEffects::Write::get(),
1523 effects.emplace_back(MemoryEffects::Read::get(),
1525 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1526 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1527 }
1528
// Adds `nRegions` regions to `state` and parses each of them in order with
// empty argument lists. Returns failure as soon as one region fails to parse.
// NOTE(review): listing line 1533, where `regions` is declared, is missing
// from this extract. `StructureOp` is not referenced in the visible body.
1529 template <typename StructureOp>
1530 static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1531 unsigned nRegions = 1) {
1532
// Register every region on the operation state before parsing any of them.
1534 for (unsigned i = 0; i < nRegions; ++i)
1535 regions.push_back(state.addRegion());
1536
1537 for (Region *region : regions)
1538 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1539 return failure();
1540
1541 return success();
1542 }
1543
1544namespace {
1545/// Pattern to remove operation without region that have constant false `ifCond`
1546/// and remove the condition from the operation if the `ifCond` is a true
1547/// constant.
1548template <typename OpTy>
1549struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1550 using OpRewritePattern<OpTy>::OpRewritePattern;
1551
1552 LogicalResult matchAndRewrite(OpTy op,
1553 PatternRewriter &rewriter) const override {
1554 // Early return if there is no condition.
1555 Value ifCond = op.getIfCond();
1556 if (!ifCond)
1557 return failure();
1558
1559 IntegerAttr constAttr;
1560 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1561 return failure();
1562 if (constAttr.getInt())
1563 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1564 else
1565 rewriter.eraseOp(op);
1566
1567 return success();
1568 }
1569};
1570
1571/// Replaces the given op with the contents of the given single-block region,
1572/// using the operands of the block terminator to replace operation results.
1573static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1574 Region &region, ValueRange blockArgs = {}) {
1575 assert(region.hasOneBlock() && "expected single-block region");
1576 Block *block = &region.front();
1577 Operation *terminator = block->getTerminator();
1578 ValueRange results = terminator->getOperands();
1579 rewriter.inlineBlockBefore(block, op, blockArgs);
1580 rewriter.replaceOp(op, results);
1581 rewriter.eraseOp(terminator);
1582}
1583
1584/// Pattern to remove operation with region that have constant false `ifCond`
1585/// and remove the condition from the operation if the `ifCond` is constant
1586/// true.
1587template <typename OpTy>
1588struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1589 using OpRewritePattern<OpTy>::OpRewritePattern;
1590
1591 LogicalResult matchAndRewrite(OpTy op,
1592 PatternRewriter &rewriter) const override {
1593 // Early return if there is no condition.
1594 Value ifCond = op.getIfCond();
1595 if (!ifCond)
1596 return failure();
1597
1598 IntegerAttr constAttr;
1599 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1600 return failure();
1601 if (constAttr.getInt())
1602 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1603 else
1604 replaceOpWithRegion(rewriter, op, op.getRegion());
1605
1606 return success();
1607 }
1608};
1609
1610//===----------------------------------------------------------------------===//
1611// Recipe Region Helpers
1612//===----------------------------------------------------------------------===//
1613
1614/// Create and populate an init region for privatization recipes.
1615/// Returns success if the region is populated, failure otherwise.
1616/// Sets needsFree to indicate if the allocated memory requires deallocation.
1617/// The `hostVar` is the original host variable used to derive
1618/// language-specific metadata via `genPrivateVariableInfo`.
1619/// The `varInfo` output parameter is set to the variable info produced.
1620static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1621 Region &initRegion, Value hostVar,
1622 StringRef varName, ValueRange bounds,
1623 bool &needsFree,
1624 acc::VariableInfoAttr &varInfo) {
1625 Type varType = hostVar.getType();
1626
1627 // Create init block with arguments: original value + bounds
1628 SmallVector<Type> argTypes{varType};
1629 SmallVector<Location> argLocs{loc};
1630 for (Value bound : bounds) {
1631 argTypes.push_back(bound.getType());
1632 argLocs.push_back(loc);
1633 }
1634
1635 Block *initBlock = builder.createBlock(&initRegion);
1636 initBlock->addArguments(argTypes, argLocs);
1637 builder.setInsertionPointToStart(initBlock);
1638
1639 Value privatizedValue;
1640
1641 // Get the block argument that represents the original variable
1642 Value blockArgVar = initBlock->getArgument(0);
1643
1644 // Generate init region body based on variable type
1645 if (isa<MappableType>(varType)) {
1646 auto mappableTy = cast<MappableType>(varType);
1647 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1648 auto typedHostVar = cast<TypedValue<MappableType>>(hostVar);
1649 varInfo = mappableTy.genPrivateVariableInfo(typedHostVar);
1650 privatizedValue = mappableTy.generatePrivateInit(
1651 builder, loc, typedVar, varName, bounds, {}, varInfo, needsFree);
1652 if (!privatizedValue)
1653 return failure();
1654 } else {
1655 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1656 auto pointerLikeTy = cast<PointerLikeType>(varType);
1657 // Use PointerLikeType's allocation API with the block argument
1658 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1659 blockArgVar, needsFree);
1660 if (!privatizedValue)
1661 return failure();
1662 }
1663
1664 // Add yield operation to init block
1665 acc::YieldOp::create(builder, loc, privatizedValue);
1666
1667 return success();
1668}
1669
1670/// Create and populate a copy region for firstprivate recipes.
1671/// Returns success if the region is populated, failure otherwise.
1672/// `varInfo` must be the attribute produced by `createInitRegion` for
1673/// `MappableType` (it is unused for `PointerLikeType` copy paths).
1674static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1675 Region &copyRegion, Type varType,
1676 ValueRange bounds,
1677 acc::VariableInfoAttr varInfo) {
1678 // Create copy block with arguments: original value + privatized value +
1679 // bounds
1680 SmallVector<Type> copyArgTypes{varType, varType};
1681 SmallVector<Location> copyArgLocs{loc, loc};
1682 for (Value bound : bounds) {
1683 copyArgTypes.push_back(bound.getType());
1684 copyArgLocs.push_back(loc);
1685 }
1686
1687 Block *copyBlock = builder.createBlock(&copyRegion);
1688 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1689 builder.setInsertionPointToStart(copyBlock);
1690
1691 Value originalArg = copyBlock->getArgument(0);
1692 Value privatizedArg = copyBlock->getArgument(1);
1693
1694 if (isa<MappableType>(varType)) {
1695 auto mappableTy = cast<MappableType>(varType);
1696 // generateCopy(src, dest): copy from original (arg0) into privatized
1697 // (arg1).
1698 if (!mappableTy.generateCopy(
1699 builder, loc, cast<TypedValue<MappableType>>(originalArg),
1700 cast<TypedValue<MappableType>>(privatizedArg), bounds, varInfo))
1701 return failure();
1702 } else {
1703 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1704 auto pointerLikeTy = cast<PointerLikeType>(varType);
1705 if (!pointerLikeTy.genCopy(
1706 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1707 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1708 return failure();
1709 }
1710
1711 // Add terminator to copy block
1712 acc::TerminatorOp::create(builder, loc);
1713
1714 return success();
1715}
1716
1717/// Create and populate a destroy region for privatization recipes.
1718/// Returns success if the region is populated, failure otherwise.
1719/// The `varInfo` carries language-specific metadata produced by
1720/// `createInitRegion`.
1721static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1722 Region &destroyRegion, Type varType,
1723 Value allocRes, ValueRange bounds,
1724 acc::VariableInfoAttr varInfo) {
1725 // Create destroy block with arguments: original value + privatized value +
1726 // bounds
1727 SmallVector<Type> destroyArgTypes{varType, varType};
1728 SmallVector<Location> destroyArgLocs{loc, loc};
1729 for (Value bound : bounds) {
1730 destroyArgTypes.push_back(bound.getType());
1731 destroyArgLocs.push_back(loc);
1732 }
1733
1734 Block *destroyBlock = builder.createBlock(&destroyRegion);
1735 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1736 builder.setInsertionPointToStart(destroyBlock);
1737
1738 auto varToFree =
1739 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1740 if (isa<MappableType>(varType)) {
1741 auto mappableTy = cast<MappableType>(varType);
1742 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds,
1743 varInfo))
1744 return failure();
1745 } else {
1746 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1747 auto pointerLikeTy = cast<PointerLikeType>(varType);
1748 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1749 return failure();
1750 }
1751
1752 acc::TerminatorOp::create(builder, loc);
1753 return success();
1754}
1755
1756} // namespace
1757
1758//===----------------------------------------------------------------------===//
1759// PrivateRecipeOp
1760//===----------------------------------------------------------------------===//
1761
// Shared region verifier for recipe operations.
//   op          - operation on which diagnostics are emitted.
//   region      - region being checked.
//   regionType  - wording used in diagnostics for the expected type.
//   regionName  - wording used in diagnostics for the region.
//   type        - type required for the first block argument (and for the
//                 yielded value when `verifyYield` is set).
//   verifyYield - also check every acc::YieldOp in the region.
//   optional    - accept an empty region.
// NOTE(review): listing line 1762, which holds the start of this definition
// (return type and name), is missing from this extract.
1763 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1764 Type type, bool verifyYield, bool optional = false) {
// An optional region may legitimately be absent.
1765 if (optional && region.empty())
1766 return success();
1767
1768 if (region.empty())
1769 return op->emitOpError() << "expects non-empty " << regionName << " region";
// The first block must take at least one argument, and that argument must
// have the expected type.
1770 Block &firstBlock = region.front();
1771 if (firstBlock.getNumArguments() < 1 ||
1772 firstBlock.getArgument(0).getType() != type)
1773 return op->emitOpError() << "expects " << regionName
1774 << " region first "
1775 "argument of the "
1776 << regionType << " type";
1777
// Every yield found in the region must carry exactly one operand of the
// expected type.
1778 if (verifyYield) {
1779 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1780 if (yieldOp.getOperands().size() != 1 ||
1781 yieldOp.getOperands().getTypes()[0] != type)
1782 return op->emitOpError() << "expects " << regionName
1783 << " region to "
1784 "yield a value of the "
1785 << regionType << " type";
1786 }
1787 }
1788 return success();
1789 }
1790
1791LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1792 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1793 "privatization", "init", getType(),
1794 /*verifyYield=*/false)))
1795 return failure();
1797 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1798 /*verifyYield=*/false, /*optional=*/true)))
1799 return failure();
1800 return success();
1801}
1802
1803std::optional<PrivateRecipeOp>
1804PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1805 StringRef recipeName, Value hostVar,
1806 StringRef varName, ValueRange bounds) {
1807 Type varType = hostVar.getType();
1808
1809 // First, validate that we can handle this variable type
1810 bool isMappable = isa<MappableType>(varType);
1811 bool isPointerLike = isa<PointerLikeType>(varType);
1812
1813 // Unsupported type
1814 if (!isMappable && !isPointerLike)
1815 return std::nullopt;
1816
1817 OpBuilder::InsertionGuard guard(builder);
1818
1819 // Create the recipe operation first so regions have proper parent context
1820 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1821
1822 // Populate the init region
1823 bool needsFree = false;
1824 acc::VariableInfoAttr varInfo;
1825 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1826 varName, bounds, needsFree, varInfo))) {
1827 recipe.erase();
1828 return std::nullopt;
1829 }
1830
1831 // Only create destroy region if the allocation needs deallocation
1832 if (needsFree) {
1833 // Extract the allocated value from the init block's yield operation
1834 auto yieldOp =
1835 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1836 Value allocRes = yieldOp.getOperand(0);
1837
1838 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1839 varType, allocRes, bounds, varInfo))) {
1840 recipe.erase();
1841 return std::nullopt;
1842 }
1843 }
1844
1845 return recipe;
1846}
1847
1848std::optional<PrivateRecipeOp>
1849PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1850 StringRef recipeName,
1851 FirstprivateRecipeOp firstprivRecipe) {
1852 // Create the private.recipe op with the same type as the firstprivate.recipe.
1853 OpBuilder::InsertionGuard guard(builder);
1854 auto varType = firstprivRecipe.getType();
1855 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1856
1857 // Clone the init region
1858 IRMapping mapping;
1859 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1860
1861 // Clone destroy region if the firstprivate.recipe has one.
1862 if (!firstprivRecipe.getDestroyRegion().empty()) {
1863 IRMapping mapping;
1864 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1865 mapping);
1866 }
1867 return recipe;
1868}
1869
1870//===----------------------------------------------------------------------===//
1871// FirstprivateRecipeOp
1872//===----------------------------------------------------------------------===//
1873
1874LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1875 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1876 "privatization", "init", getType(),
1877 /*verifyYield=*/false)))
1878 return failure();
1879
1880 if (getCopyRegion().empty())
1881 return emitOpError() << "expects non-empty copy region";
1882
1883 Block &firstBlock = getCopyRegion().front();
1884 if (firstBlock.getNumArguments() < 2 ||
1885 firstBlock.getArgument(0).getType() != getType())
1886 return emitOpError() << "expects copy region with two arguments of the "
1887 "privatization type";
1888
1889 if (getDestroyRegion().empty())
1890 return success();
1891
1892 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1893 "privatization", "destroy",
1894 getType(), /*verifyYield=*/false)))
1895 return failure();
1896
1897 return success();
1898}
1899
1900std::optional<FirstprivateRecipeOp>
1901FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1902 StringRef recipeName, Value hostVar,
1903 StringRef varName, ValueRange bounds) {
1904 Type varType = hostVar.getType();
1905
1906 // First, validate that we can handle this variable type
1907 bool isMappable = isa<MappableType>(varType);
1908 bool isPointerLike = isa<PointerLikeType>(varType);
1909
1910 // Unsupported type
1911 if (!isMappable && !isPointerLike)
1912 return std::nullopt;
1913
1914 OpBuilder::InsertionGuard guard(builder);
1915
1916 // Create the recipe operation first so regions have proper parent context
1917 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1918
1919 // Populate the init region
1920 bool needsFree = false;
1921 // Filled by createInitRegion for mappable variables (genPrivateVariableInfo);
1922 // then passed through to copy/destroy so generateCopy /
1923 // generatePrivateDestroy receive the same metadata as generatePrivateInit.
1924 acc::VariableInfoAttr varInfo;
1925 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), hostVar,
1926 varName, bounds, needsFree, varInfo))) {
1927 recipe.erase();
1928 return std::nullopt;
1929 }
1930
1931 // Populate the copy region (uses varInfo for MappableType::generateCopy).
1932 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1933 bounds, varInfo))) {
1934 recipe.erase();
1935 return std::nullopt;
1936 }
1937
1938 // Only create destroy region if the allocation needs deallocation
1939 if (needsFree) {
1940 // Extract the allocated value from the init block's yield operation
1941 auto yieldOp =
1942 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1943 Value allocRes = yieldOp.getOperand(0);
1944
1945 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1946 varType, allocRes, bounds, varInfo))) {
1947 recipe.erase();
1948 return std::nullopt;
1949 }
1950 }
1951
1952 return recipe;
1953}
1954
1955//===----------------------------------------------------------------------===//
1956// ReductionRecipeOp
1957//===----------------------------------------------------------------------===//
1958
1959LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1960 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1961 "init", getType(),
1962 /*verifyYield=*/false)))
1963 return failure();
1964
1965 if (getCombinerRegion().empty())
1966 return emitOpError() << "expects non-empty combiner region";
1967
1968 Block &reductionBlock = getCombinerRegion().front();
1969 if (reductionBlock.getNumArguments() < 2 ||
1970 reductionBlock.getArgument(0).getType() != getType() ||
1971 reductionBlock.getArgument(1).getType() != getType())
1972 return emitOpError() << "expects combiner region with the first two "
1973 << "arguments of the reduction type";
1974
1975 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1976 if (yieldOp.getOperands().size() != 1 ||
1977 yieldOp.getOperands().getTypes()[0] != getType())
1978 return emitOpError() << "expects combiner region to yield a value "
1979 "of the reduction type";
1980 }
1981
1982 return success();
1983}
1984
1985//===----------------------------------------------------------------------===//
1986// ParallelOp
1987//===----------------------------------------------------------------------===//
1988
1989/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1990template <typename Op>
1991static LogicalResult checkDataOperands(Op op,
1992 const mlir::ValueRange &operands) {
1993 for (mlir::Value operand : operands)
1994 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1995 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1996 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1997 operand.getDefiningOp()))
1998 return op.emitError(
1999 "expect data entry/exit operation or acc.getdeviceptr "
2000 "as defining op");
2001 return success();
2002}
2003
2004template <typename OpT, typename RecipeOpT>
2005static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
2006 const mlir::ValueRange &operands,
2007 llvm::StringRef operandName) {
2009 for (mlir::Value operand : operands) {
2010 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2011 return accConstructOp->emitOpError()
2012 << "expected " << operandName << " as defining op";
2013 if (!set.insert(operand).second)
2014 return accConstructOp->emitOpError()
2015 << operandName << " operand appears more than once";
2016 }
2017 return success();
2018}
2019
2020unsigned ParallelOp::getNumDataOperands() {
2021 return getReductionOperands().size() + getPrivateOperands().size() +
2022 getFirstprivateOperands().size() + getDataClauseOperands().size();
2023}
2024
2025Value ParallelOp::getDataOperand(unsigned i) {
2026 unsigned numOptional = getAsyncOperands().size();
2027 numOptional += getNumGangs().size();
2028 numOptional += getNumWorkers().size();
2029 numOptional += getVectorLength().size();
2030 numOptional += getIfCond() ? 1 : 0;
2031 numOptional += getSelfCond() ? 1 : 0;
2032 return getOperand(getWaitOperands().size() + numOptional + i);
2033}
2034
2035template <typename Op>
2036static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
2037 ArrayAttr deviceTypes,
2038 llvm::StringRef keyword) {
2039 if (!operands.empty() &&
2040 (!deviceTypes || deviceTypes.getValue().size() != operands.size()))
2041 return op.emitOpError() << keyword << " operands count must match "
2042 << keyword << " device_type count";
2043 return success();
2044}
2045
2046template <typename Op>
2048 Op op, OperandRange operands, DenseI32ArrayAttr segments,
2049 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2050 std::size_t numOperandsInSegments = 0;
2051 std::size_t nbOfSegments = 0;
2052
2053 if (segments) {
2054 for (auto segCount : segments.asArrayRef()) {
2055 if (maxInSegment != 0 && segCount > maxInSegment)
2056 return op.emitOpError() << keyword << " expects a maximum of "
2057 << maxInSegment << " values per segment";
2058 numOperandsInSegments += segCount;
2059 ++nbOfSegments;
2060 }
2061 }
2062
2063 if ((numOperandsInSegments != operands.size()) ||
2064 (!deviceTypes && !operands.empty()))
2065 return op.emitOpError()
2066 << keyword << " operand count does not match count in segments";
2067 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2068 return op.emitOpError()
2069 << keyword << " segment count does not match device_type count";
2070 return success();
2071}
2072
2073LogicalResult acc::ParallelOp::verify() {
2074 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2075 mlir::acc::PrivateRecipeOp>(
2076 *this, getPrivateOperands(), "private")))
2077 return failure();
2078 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2079 mlir::acc::FirstprivateRecipeOp>(
2080 *this, getFirstprivateOperands(), "firstprivate")))
2081 return failure();
2082 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2083 mlir::acc::ReductionRecipeOp>(
2084 *this, getReductionOperands(), "reduction")))
2085 return failure();
2086
2088 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2089 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2090 return failure();
2091
2093 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2094 getWaitOperandsDeviceTypeAttr(), "wait")))
2095 return failure();
2096
2097 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2098 getNumWorkersDeviceTypeAttr(),
2099 "num_workers")))
2100 return failure();
2101
2102 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2103 getVectorLengthDeviceTypeAttr(),
2104 "vector_length")))
2105 return failure();
2106
2108 getAsyncOperandsDeviceTypeAttr(),
2109 "async")))
2110 return failure();
2111
2113 return failure();
2114
2115 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
2116}
2117
2118static mlir::Value
2119getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
2121 mlir::acc::DeviceType deviceType) {
2122 if (!arrayAttr)
2123 return {};
2124 if (auto pos = findSegment(*arrayAttr, deviceType))
2125 return range[*pos];
2126 return {};
2127}
2128
2129bool acc::ParallelOp::hasAsyncOnly() {
2130 return hasAsyncOnly(mlir::acc::DeviceType::None);
2131}
2132
2133bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2134 return hasDeviceType(getAsyncOnly(), deviceType);
2135}
2136
2137mlir::Value acc::ParallelOp::getAsyncValue() {
2138 return getAsyncValue(mlir::acc::DeviceType::None);
2139}
2140
2141mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2143 getAsyncOperands(), deviceType);
2144}
2145
2146mlir::Value acc::ParallelOp::getNumWorkersValue() {
2147 return getNumWorkersValue(mlir::acc::DeviceType::None);
2148}
2149
2151acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2152 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2153 deviceType);
2154}
2155
2156mlir::Value acc::ParallelOp::getVectorLengthValue() {
2157 return getVectorLengthValue(mlir::acc::DeviceType::None);
2158}
2159
2161acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2162 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2163 getVectorLength(), deviceType);
2164}
2165
2166mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2167 return getNumGangsValues(mlir::acc::DeviceType::None);
2168}
2169
2171ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2172 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2173 getNumGangsSegments(), deviceType);
2174}
2175
2176bool acc::ParallelOp::hasWaitOnly() {
2177 return hasWaitOnly(mlir::acc::DeviceType::None);
2178}
2179
2180bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2181 return hasDeviceType(getWaitOnly(), deviceType);
2182}
2183
2184mlir::Operation::operand_range ParallelOp::getWaitValues() {
2185 return getWaitValues(mlir::acc::DeviceType::None);
2186}
2187
2189ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2191 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2192 getHasWaitDevnum(), deviceType);
2193}
2194
2195mlir::Value ParallelOp::getWaitDevnum() {
2196 return getWaitDevnum(mlir::acc::DeviceType::None);
2197}
2198
2199mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2200 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2201 getWaitOperandsSegments(), getHasWaitDevnum(),
2202 deviceType);
2203}
2204
2205void ParallelOp::build(mlir::OpBuilder &odsBuilder,
2206 mlir::OperationState &odsState,
2207 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
2208 mlir::ValueRange vectorLength,
2209 mlir::ValueRange asyncOperands,
2210 mlir::ValueRange waitOperands, mlir::Value ifCond,
2211 mlir::Value selfCond, mlir::ValueRange reductionOperands,
2212 mlir::ValueRange gangPrivateOperands,
2213 mlir::ValueRange gangFirstPrivateOperands,
2214 mlir::ValueRange dataClauseOperands) {
2215 ParallelOp::build(
2216 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
2217 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
2218 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
2219 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
2220 /*numGangsDeviceType=*/nullptr, numWorkers,
2221 /*numWorkersDeviceType=*/nullptr, vectorLength,
2222 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
2223 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
2224 gangFirstPrivateOperands, dataClauseOperands,
2225 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
2226}
2227
2228void acc::ParallelOp::addNumWorkersOperand(
2229 MLIRContext *context, mlir::Value newValue,
2230 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2231 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2232 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2233 getNumWorkersMutable()));
2234}
2235void acc::ParallelOp::addVectorLengthOperand(
2236 MLIRContext *context, mlir::Value newValue,
2237 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2238 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2239 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2240 getVectorLengthMutable()));
2241}
2242
2243void acc::ParallelOp::addAsyncOnly(
2244 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2245 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2246 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2247}
2248
2249void acc::ParallelOp::addAsyncOperand(
2250 MLIRContext *context, mlir::Value newValue,
2251 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2252 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2253 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2254 getAsyncOperandsMutable()));
2255}
2256
2257void acc::ParallelOp::addNumGangsOperands(
2258 MLIRContext *context, mlir::ValueRange newValues,
2259 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2261 if (getNumGangsSegments())
2262 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2263
2264 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2265 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2266 getNumGangsMutable(), segments));
2267
2268 setNumGangsSegments(segments);
2269}
2270void acc::ParallelOp::addWaitOnly(
2271 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2272 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2273 effectiveDeviceTypes));
2274}
2275void acc::ParallelOp::addWaitOperands(
2276 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2277 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2278
2280 if (getWaitOperandsSegments())
2281 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2282
2283 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2284 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2285 getWaitOperandsMutable(), segments));
2286 setWaitOperandsSegments(segments);
2287
2289 if (getHasWaitDevnumAttr())
2290 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2291 hasDevnums.insert(
2292 hasDevnums.end(),
2293 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2294 mlir::BoolAttr::get(context, hasDevnum));
2295 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2296}
2297
2298void acc::ParallelOp::addPrivatization(MLIRContext *context,
2299 mlir::acc::PrivateOp op,
2300 mlir::acc::PrivateRecipeOp recipe) {
2301 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2302 getPrivateOperandsMutable().append(op.getResult());
2303}
2304
2305void acc::ParallelOp::addFirstPrivatization(
2306 MLIRContext *context, mlir::acc::FirstprivateOp op,
2307 mlir::acc::FirstprivateRecipeOp recipe) {
2308 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2309 getFirstprivateOperandsMutable().append(op.getResult());
2310}
2311
2312void acc::ParallelOp::addReduction(MLIRContext *context,
2313 mlir::acc::ReductionOp op,
2314 mlir::acc::ReductionRecipeOp recipe) {
2315 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2316 getReductionOperandsMutable().append(op.getResult());
2317}
2318
2319static ParseResult parseNumGangs(
2320 mlir::OpAsmParser &parser,
2322 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2323 mlir::DenseI32ArrayAttr &segments) {
2326
2327 do {
2328 if (failed(parser.parseLBrace()))
2329 return failure();
2330
2331 int32_t crtOperandsSize = operands.size();
2332 if (failed(parser.parseCommaSeparatedList(
2334 if (parser.parseOperand(operands.emplace_back()) ||
2335 parser.parseColonType(types.emplace_back()))
2336 return failure();
2337 return success();
2338 })))
2339 return failure();
2340 seg.push_back(operands.size() - crtOperandsSize);
2341
2342 if (failed(parser.parseRBrace()))
2343 return failure();
2344
2345 if (succeeded(parser.parseOptionalLSquare())) {
2346 if (parser.parseAttribute(attributes.emplace_back()) ||
2347 parser.parseRSquare())
2348 return failure();
2349 } else {
2350 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2351 parser.getContext(), mlir::acc::DeviceType::None));
2352 }
2353 } while (succeeded(parser.parseOptionalComma()));
2354
2355 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2356 attributes.end());
2357 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2358 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2359
2360 return success();
2361}
2362
2364 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2365 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2366 p << " [" << attr << "]";
2367}
2368
2370 mlir::OperandRange operands, mlir::TypeRange types,
2371 std::optional<mlir::ArrayAttr> deviceTypes,
2372 std::optional<mlir::DenseI32ArrayAttr> segments) {
2373 unsigned opIdx = 0;
2374 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2375 p << "{";
2376 llvm::interleaveComma(
2377 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2378 p << operands[opIdx] << " : " << operands[opIdx].getType();
2379 ++opIdx;
2380 });
2381 p << "}";
2382 printSingleDeviceType(p, it.value());
2383 });
2384}
2385
2387 mlir::OpAsmParser &parser,
2389 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2390 mlir::DenseI32ArrayAttr &segments) {
2393
2394 do {
2395 if (failed(parser.parseLBrace()))
2396 return failure();
2397
2398 int32_t crtOperandsSize = operands.size();
2399
2400 if (failed(parser.parseCommaSeparatedList(
2402 if (parser.parseOperand(operands.emplace_back()) ||
2403 parser.parseColonType(types.emplace_back()))
2404 return failure();
2405 return success();
2406 })))
2407 return failure();
2408
2409 seg.push_back(operands.size() - crtOperandsSize);
2410
2411 if (failed(parser.parseRBrace()))
2412 return failure();
2413
2414 if (succeeded(parser.parseOptionalLSquare())) {
2415 if (parser.parseAttribute(attributes.emplace_back()) ||
2416 parser.parseRSquare())
2417 return failure();
2418 } else {
2419 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2420 parser.getContext(), mlir::acc::DeviceType::None));
2421 }
2422 } while (succeeded(parser.parseOptionalComma()));
2423
2424 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2425 attributes.end());
2426 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2427 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2428
2429 return success();
2430}
2431
2434 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2435 std::optional<mlir::DenseI32ArrayAttr> segments) {
2436 unsigned opIdx = 0;
2437 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2438 p << "{";
2439 llvm::interleaveComma(
2440 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2441 p << operands[opIdx] << " : " << operands[opIdx].getType();
2442 ++opIdx;
2443 });
2444 p << "}";
2445 printSingleDeviceType(p, it.value());
2446 });
2447}
2448
2449static ParseResult parseWaitClause(
2450 mlir::OpAsmParser &parser,
2452 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2453 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2454 mlir::ArrayAttr &keywordOnly) {
2455 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2457
2458 bool needCommaBeforeOperands = false;
2459
2460 // Keyword only
2461 if (failed(parser.parseOptionalLParen())) {
2462 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2463 parser.getContext(), mlir::acc::DeviceType::None));
2464 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2465 return success();
2466 }
2467
2468 // Parse keyword only attributes
2469 if (succeeded(parser.parseOptionalLSquare())) {
2470 if (failed(parser.parseCommaSeparatedList([&]() {
2471 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2472 return failure();
2473 return success();
2474 })))
2475 return failure();
2476 if (parser.parseRSquare())
2477 return failure();
2478 needCommaBeforeOperands = true;
2479 }
2480
2481 if (needCommaBeforeOperands && failed(parser.parseComma()))
2482 return failure();
2483
2484 do {
2485 if (failed(parser.parseLBrace()))
2486 return failure();
2487
2488 int32_t crtOperandsSize = operands.size();
2489
2490 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2491 if (failed(parser.parseColon()))
2492 return failure();
2493 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2494 } else {
2495 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2496 }
2497
2498 if (failed(parser.parseCommaSeparatedList(
2500 if (parser.parseOperand(operands.emplace_back()) ||
2501 parser.parseColonType(types.emplace_back()))
2502 return failure();
2503 return success();
2504 })))
2505 return failure();
2506
2507 seg.push_back(operands.size() - crtOperandsSize);
2508
2509 if (failed(parser.parseRBrace()))
2510 return failure();
2511
2512 if (succeeded(parser.parseOptionalLSquare())) {
2513 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2514 parser.parseRSquare())
2515 return failure();
2516 } else {
2517 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2518 parser.getContext(), mlir::acc::DeviceType::None));
2519 }
2520 } while (succeeded(parser.parseOptionalComma()));
2521
2522 if (failed(parser.parseRParen()))
2523 return failure();
2524
2525 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2526 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2527 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2528 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2529
2530 return success();
2531}
2532
2533static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2534 if (!hasDeviceTypeValues(attrs))
2535 return false;
2536 if (attrs->size() != 1)
2537 return false;
2538 if (auto deviceTypeAttr =
2539 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2540 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2541 return false;
2542}
2543
2545 mlir::OperandRange operands, mlir::TypeRange types,
2546 std::optional<mlir::ArrayAttr> deviceTypes,
2547 std::optional<mlir::DenseI32ArrayAttr> segments,
2548 std::optional<mlir::ArrayAttr> hasDevNum,
2549 std::optional<mlir::ArrayAttr> keywordOnly) {
2550
2551 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2552 return;
2553
2554 p << "(";
2555
2556 printDeviceTypes(p, keywordOnly);
2557 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2558 p << ", ";
2559
2560 if (hasDeviceTypeValues(deviceTypes)) {
2561 unsigned opIdx = 0;
2562 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2563 p << "{";
2564 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2565 if (boolAttr && boolAttr.getValue())
2566 p << "devnum: ";
2567 llvm::interleaveComma(
2568 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2569 p << operands[opIdx] << " : " << operands[opIdx].getType();
2570 ++opIdx;
2571 });
2572 p << "}";
2573 printSingleDeviceType(p, it.value());
2574 });
2575 }
2576
2577 p << ")";
2578}
2579
2580static ParseResult parseDeviceTypeOperands(
2581 mlir::OpAsmParser &parser,
2583 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2585 if (failed(parser.parseCommaSeparatedList([&]() {
2586 if (parser.parseOperand(operands.emplace_back()) ||
2587 parser.parseColonType(types.emplace_back()))
2588 return failure();
2589 if (succeeded(parser.parseOptionalLSquare())) {
2590 if (parser.parseAttribute(attributes.emplace_back()) ||
2591 parser.parseRSquare())
2592 return failure();
2593 } else {
2594 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2595 parser.getContext(), mlir::acc::DeviceType::None));
2596 }
2597 return success();
2598 })))
2599 return failure();
2600 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2601 attributes.end());
2602 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2603 return success();
2604}
2605
2606static void
2608 mlir::OperandRange operands, mlir::TypeRange types,
2609 std::optional<mlir::ArrayAttr> deviceTypes) {
2610 if (!hasDeviceTypeValues(deviceTypes))
2611 return;
2612 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2613 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2614 printSingleDeviceType(p, std::get<0>(it));
2615 });
2616}
2617
2619 mlir::OpAsmParser &parser,
2621 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2622 mlir::ArrayAttr &keywordOnlyDeviceType) {
2623
2624 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2625 bool needCommaBeforeOperands = false;
2626
2627 if (failed(parser.parseOptionalLParen())) {
2628 // Keyword only
2629 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2630 parser.getContext(), mlir::acc::DeviceType::None));
2631 keywordOnlyDeviceType =
2632 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2633 return success();
2634 }
2635
2636 // Parse keyword only attributes
2637 if (succeeded(parser.parseOptionalLSquare())) {
2638 // Parse keyword only attributes
2639 if (failed(parser.parseCommaSeparatedList([&]() {
2640 if (parser.parseAttribute(
2641 keywordOnlyDeviceTypeAttributes.emplace_back()))
2642 return failure();
2643 return success();
2644 })))
2645 return failure();
2646 if (parser.parseRSquare())
2647 return failure();
2648 needCommaBeforeOperands = true;
2649 }
2650
2651 if (needCommaBeforeOperands && failed(parser.parseComma()))
2652 return failure();
2653
2655 if (failed(parser.parseCommaSeparatedList([&]() {
2656 if (parser.parseOperand(operands.emplace_back()) ||
2657 parser.parseColonType(types.emplace_back()))
2658 return failure();
2659 if (succeeded(parser.parseOptionalLSquare())) {
2660 if (parser.parseAttribute(attributes.emplace_back()) ||
2661 parser.parseRSquare())
2662 return failure();
2663 } else {
2664 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2665 parser.getContext(), mlir::acc::DeviceType::None));
2666 }
2667 return success();
2668 })))
2669 return failure();
2670
2671 if (failed(parser.parseRParen()))
2672 return failure();
2673
2674 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2675 attributes.end());
2676 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2677 return success();
2678}
2679
2682 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2683 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2684
2685 if (operands.begin() == operands.end() &&
2686 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2687 return;
2688 }
2689
2690 p << "(";
2691 printDeviceTypes(p, keywordOnlyDeviceTypes);
2692 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2693 hasDeviceTypeValues(deviceTypes))
2694 p << ", ";
2695 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2696 p << ")";
2697}
2698
2700 mlir::OpAsmParser &parser,
2701 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2702 mlir::Type &operandType, mlir::UnitAttr &attr) {
2703 // Keyword only
2704 if (failed(parser.parseOptionalLParen())) {
2705 attr = mlir::UnitAttr::get(parser.getContext());
2706 return success();
2707 }
2708
2710 if (failed(parser.parseOperand(op)))
2711 return failure();
2712 operand = op;
2713 if (failed(parser.parseColon()))
2714 return failure();
2715 if (failed(parser.parseType(operandType)))
2716 return failure();
2717 if (failed(parser.parseRParen()))
2718 return failure();
2719
2720 return success();
2721}
2722
2724 mlir::Operation *op,
2725 std::optional<mlir::Value> operand,
2726 mlir::Type operandType,
2727 mlir::UnitAttr attr) {
2728 if (attr)
2729 return;
2730
2731 p << "(";
2732 p.printOperand(*operand);
2733 p << " : ";
2734 p.printType(operandType);
2735 p << ")";
2736}
2737
2739 mlir::OpAsmParser &parser,
2741 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2742 // Keyword only
2743 if (failed(parser.parseOptionalLParen())) {
2744 attr = mlir::UnitAttr::get(parser.getContext());
2745 return success();
2746 }
2747
2748 if (failed(parser.parseCommaSeparatedList([&]() {
2749 if (parser.parseOperand(operands.emplace_back()))
2750 return failure();
2751 return success();
2752 })))
2753 return failure();
2754 if (failed(parser.parseColon()))
2755 return failure();
2756 if (failed(parser.parseCommaSeparatedList([&]() {
2757 if (parser.parseType(types.emplace_back()))
2758 return failure();
2759 return success();
2760 })))
2761 return failure();
2762 if (failed(parser.parseRParen()))
2763 return failure();
2764
2765 return success();
2766}
2767
2769 mlir::Operation *op,
2770 mlir::OperandRange operands,
2771 mlir::TypeRange types,
2772 mlir::UnitAttr attr) {
2773 if (attr)
2774 return;
2775
2776 p << "(";
2777 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2778 p << " : ";
2779 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2780 p << ")";
2781}
2782
/// Custom-assembly parser for the combined-construct marker.
/// Accepts exactly one of the keywords `kernels`, `parallel` or `serial` and
/// stores the matching CombinedConstructsTypeAttr (KernelsLoop, ParallelLoop
/// or SerialLoop respectively) in `attr`. If none of the three keywords is
/// present, an error is emitted at the current location and failure is
/// returned.
/// NOTE(review): the line holding the function name and the `parser`
/// parameter declaration is not visible in this excerpt.
2783static ParseResult
2785 mlir::acc::CombinedConstructsTypeAttr &attr) {
2786 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2787 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2788 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2789 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2790 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2791 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2792 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2793 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2794 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2795 } else {
// None of the three construct keywords matched: report and fail.
2796 parser.emitError(parser.getCurrentLocation(),
2797 "expected compute construct name");
2798 return failure();
2799 }
2800 return success();
2801}
2802
/// Custom-assembly printer for the combined-construct marker.
/// Writes `kernels`, `parallel` or `serial` to `p` for the KernelsLoop,
/// ParallelLoop and SerialLoop enum values respectively. Nothing is printed
/// when `attr` is null.
/// NOTE(review): the line holding the function name and the leading
/// parameters (including `p`) is not visible in this excerpt.
2803static void
2805 mlir::acc::CombinedConstructsTypeAttr attr) {
2806 if (attr) {
2807 switch (attr.getValue()) {
2808 case mlir::acc::CombinedConstructsType::KernelsLoop:
2809 p << "kernels";
2810 break;
2811 case mlir::acc::CombinedConstructsType::ParallelLoop:
2812 p << "parallel";
2813 break;
2814 case mlir::acc::CombinedConstructsType::SerialLoop:
2815 p << "serial";
2816 break;
2817 };
2818 }
2819}
2820
2821//===----------------------------------------------------------------------===//
2822// SerialOp
2823//===----------------------------------------------------------------------===//
2824
2825unsigned SerialOp::getNumDataOperands() {
2826 return getReductionOperands().size() + getPrivateOperands().size() +
2827 getFirstprivateOperands().size() + getDataClauseOperands().size();
2828}
2829
2830Value SerialOp::getDataOperand(unsigned i) {
2831 unsigned numOptional = getAsyncOperands().size();
2832 numOptional += getIfCond() ? 1 : 0;
2833 numOptional += getSelfCond() ? 1 : 0;
2834 return getOperand(getWaitOperands().size() + numOptional + i);
2835}
2836
2837bool acc::SerialOp::hasAsyncOnly() {
2838 return hasAsyncOnly(mlir::acc::DeviceType::None);
2839}
2840
2841bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2842 return hasDeviceType(getAsyncOnly(), deviceType);
2843}
2844
2845mlir::Value acc::SerialOp::getAsyncValue() {
2846 return getAsyncValue(mlir::acc::DeviceType::None);
2847}
2848
/// Device-type-specific lookup of the async value; the visible arguments are
/// the async operands and `deviceType`.
/// NOTE(review): the line that opens the return expression (callee and first
/// argument) is not visible in this excerpt.
2849mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2851 getAsyncOperands(), deviceType);
2852}
2853
2854bool acc::SerialOp::hasWaitOnly() {
2855 return hasWaitOnly(mlir::acc::DeviceType::None);
2856}
2857
2858bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2859 return hasDeviceType(getWaitOnly(), deviceType);
2860}
2861
2862mlir::Operation::operand_range SerialOp::getWaitValues() {
2863 return getWaitValues(mlir::acc::DeviceType::None);
2864}
2865
2867SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2869 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2870 getHasWaitDevnum(), deviceType);
2871}
2872
2873mlir::Value SerialOp::getWaitDevnum() {
2874 return getWaitDevnum(mlir::acc::DeviceType::None);
2875}
2876
2877mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2878 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2879 getWaitOperandsSegments(), getHasWaitDevnum(),
2880 deviceType);
2881}
2882
/// Verifier for acc.serial.
/// Runs, in order: the private / firstprivate / reduction operand checks
/// (checkPrivateOperands), checks tagged "wait" and "async" over the
/// corresponding operands and device_type attributes, one further guarded
/// check, and finally checkDataOperands on the data clause operands. The
/// first failing check makes the verifier return failure.
/// NOTE(review): three lines of this function (the openings of the "wait" and
/// "async" checks and the condition guarding the last early return) are not
/// visible in this excerpt, so their callees are not named here.
2883LogicalResult acc::SerialOp::verify() {
2884 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2885 mlir::acc::PrivateRecipeOp>(
2886 *this, getPrivateOperands(), "private")))
2887 return failure();
2888 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2889 mlir::acc::FirstprivateRecipeOp>(
2890 *this, getFirstprivateOperands(), "firstprivate")))
2891 return failure();
2892 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2893 mlir::acc::ReductionRecipeOp>(
2894 *this, getReductionOperands(), "reduction")))
2895 return failure();
2896
// Wait operands are checked together with their segments and device types.
2898 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2899 getWaitOperandsDeviceTypeAttr(), "wait")))
2900 return failure();
2901
// Async operands are checked against their device_type attribute.
2903 getAsyncOperandsDeviceTypeAttr(),
2904 "async")))
2905 return failure();
2906
2908 return failure();
2909
2910 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2911}
2912
2913void acc::SerialOp::addAsyncOnly(
2914 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2915 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2916 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2917}
2918
2919void acc::SerialOp::addAsyncOperand(
2920 MLIRContext *context, mlir::Value newValue,
2921 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2922 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2923 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2924 getAsyncOperandsMutable()));
2925}
2926
2927void acc::SerialOp::addWaitOnly(
2928 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2929 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2930 effectiveDeviceTypes));
2931}
/// Appends a wait clause with operands to this op.
/// The existing wait operand segments (if any) are copied into `segments`,
/// addDeviceTypeAffectedOperandHelper is called with `newValues`, the mutable
/// wait operand range and `segments`, and its result becomes the new
/// waitOperandsDeviceType attribute; `segments` is then stored back on the
/// op. Finally BoolAttr(`hasDevnum`) entries are appended to the existing
/// hasWaitDevnum array and the attribute is rebuilt.
/// NOTE(review): the declarations of `segments` and `hasDevnums` are on lines
/// not visible in this excerpt.
2932void acc::SerialOp::addWaitOperands(
2933 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2934 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2935
2937 if (getWaitOperandsSegments())
2938 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2939
2940 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2941 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2942 getWaitOperandsMutable(), segments));
2943 setWaitOperandsSegments(segments);
2944
2946 if (getHasWaitDevnumAttr())
2947 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// One flag per effective device type, but never fewer than one: an empty
// `effectiveDeviceTypes` still appends a single entry.
2948 hasDevnums.insert(
2949 hasDevnums.end(),
2950 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2951 mlir::BoolAttr::get(context, hasDevnum));
2952 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2953}
2954
2955void acc::SerialOp::addPrivatization(MLIRContext *context,
2956 mlir::acc::PrivateOp op,
2957 mlir::acc::PrivateRecipeOp recipe) {
2958 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2959 getPrivateOperandsMutable().append(op.getResult());
2960}
2961
2962void acc::SerialOp::addFirstPrivatization(
2963 MLIRContext *context, mlir::acc::FirstprivateOp op,
2964 mlir::acc::FirstprivateRecipeOp recipe) {
2965 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2966 getFirstprivateOperandsMutable().append(op.getResult());
2967}
2968
2969void acc::SerialOp::addReduction(MLIRContext *context,
2970 mlir::acc::ReductionOp op,
2971 mlir::acc::ReductionRecipeOp recipe) {
2972 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2973 getReductionOperandsMutable().append(op.getResult());
2974}
2975
2976//===----------------------------------------------------------------------===//
2977// KernelsOp
2978//===----------------------------------------------------------------------===//
2979
2980unsigned KernelsOp::getNumDataOperands() {
2981 return getDataClauseOperands().size();
2982}
2983
2984Value KernelsOp::getDataOperand(unsigned i) {
2985 unsigned numOptional = getAsyncOperands().size();
2986 numOptional += getWaitOperands().size();
2987 numOptional += getNumGangs().size();
2988 numOptional += getNumWorkers().size();
2989 numOptional += getVectorLength().size();
2990 numOptional += getIfCond() ? 1 : 0;
2991 numOptional += getSelfCond() ? 1 : 0;
2992 return getOperand(numOptional + i);
2993}
2994
2995bool acc::KernelsOp::hasAsyncOnly() {
2996 return hasAsyncOnly(mlir::acc::DeviceType::None);
2997}
2998
2999bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3000 return hasDeviceType(getAsyncOnly(), deviceType);
3001}
3002
3003mlir::Value acc::KernelsOp::getAsyncValue() {
3004 return getAsyncValue(mlir::acc::DeviceType::None);
3005}
3006
/// Device-type-specific lookup of the async value; the visible arguments are
/// the async operands and `deviceType`.
/// NOTE(review): the line that opens the return expression (callee and first
/// argument) is not visible in this excerpt.
3007mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3009 getAsyncOperands(), deviceType);
3010}
3011
3012mlir::Value acc::KernelsOp::getNumWorkersValue() {
3013 return getNumWorkersValue(mlir::acc::DeviceType::None);
3014}
3015
3017acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3018 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
3019 deviceType);
3020}
3021
3022mlir::Value acc::KernelsOp::getVectorLengthValue() {
3023 return getVectorLengthValue(mlir::acc::DeviceType::None);
3024}
3025
3027acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3028 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
3029 getVectorLength(), deviceType);
3030}
3031
3032mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
3033 return getNumGangsValues(mlir::acc::DeviceType::None);
3034}
3035
3037KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3038 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
3039 getNumGangsSegments(), deviceType);
3040}
3041
3042bool acc::KernelsOp::hasWaitOnly() {
3043 return hasWaitOnly(mlir::acc::DeviceType::None);
3044}
3045
3046bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3047 return hasDeviceType(getWaitOnly(), deviceType);
3048}
3049
3050mlir::Operation::operand_range KernelsOp::getWaitValues() {
3051 return getWaitValues(mlir::acc::DeviceType::None);
3052}
3053
3055KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3057 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3058 getHasWaitDevnum(), deviceType);
3059}
3060
3061mlir::Value KernelsOp::getWaitDevnum() {
3062 return getWaitDevnum(mlir::acc::DeviceType::None);
3063}
3064
3065mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3066 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3067 getWaitOperandsSegments(), getHasWaitDevnum(),
3068 deviceType);
3069}
3070
/// Verifier for acc.kernels.
/// Runs a sequence of checks and returns failure at the first one that
/// fails: a "num_gangs" check (called with the extra argument 3), a "wait"
/// check, verifyDeviceTypeCountMatch for "num_workers" and "vector_length",
/// an "async" check, one further guarded check, and finally
/// checkDataOperands on the data clause operands.
/// NOTE(review): the opening lines of the "num_gangs", "wait" and "async"
/// checks and the condition guarding the last early return are not visible
/// in this excerpt, so their callees are not named here.
3071LogicalResult acc::KernelsOp::verify() {
3073 *this, getNumGangs(), getNumGangsSegmentsAttr(),
3074 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
3075 return failure();
3076
3078 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3079 getWaitOperandsDeviceTypeAttr(), "wait")))
3080 return failure();
3081
3082 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
3083 getNumWorkersDeviceTypeAttr(),
3084 "num_workers")))
3085 return failure();
3086
3087 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
3088 getVectorLengthDeviceTypeAttr(),
3089 "vector_length")))
3090 return failure();
3091
3093 getAsyncOperandsDeviceTypeAttr(),
3094 "async")))
3095 return failure();
3096
3098 return failure();
3099
3100 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
3101}
3102
3103void acc::KernelsOp::addPrivatization(MLIRContext *context,
3104 mlir::acc::PrivateOp op,
3105 mlir::acc::PrivateRecipeOp recipe) {
3106 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3107 getPrivateOperandsMutable().append(op.getResult());
3108}
3109
3110void acc::KernelsOp::addFirstPrivatization(
3111 MLIRContext *context, mlir::acc::FirstprivateOp op,
3112 mlir::acc::FirstprivateRecipeOp recipe) {
3113 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3114 getFirstprivateOperandsMutable().append(op.getResult());
3115}
3116
3117void acc::KernelsOp::addReduction(MLIRContext *context,
3118 mlir::acc::ReductionOp op,
3119 mlir::acc::ReductionRecipeOp recipe) {
3120 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3121 getReductionOperandsMutable().append(op.getResult());
3122}
3123
3124void acc::KernelsOp::addNumWorkersOperand(
3125 MLIRContext *context, mlir::Value newValue,
3126 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3127 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3128 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3129 getNumWorkersMutable()));
3130}
3131
3132void acc::KernelsOp::addVectorLengthOperand(
3133 MLIRContext *context, mlir::Value newValue,
3134 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3135 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3136 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3137 getVectorLengthMutable()));
3138}
3139void acc::KernelsOp::addAsyncOnly(
3140 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3141 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3142 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3143}
3144
3145void acc::KernelsOp::addAsyncOperand(
3146 MLIRContext *context, mlir::Value newValue,
3147 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3148 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3149 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3150 getAsyncOperandsMutable()));
3151}
3152
/// Appends a group of num_gangs operands for `effectiveDeviceTypes`.
/// Existing num_gangs segments (if the attribute is set) are copied into
/// `segments`; addDeviceTypeAffectedOperandHelper is then called with
/// `newValues`, the mutable num_gangs operand range and `segments`, and its
/// result becomes the new numGangsDeviceType attribute. `segments` is stored
/// back on the op afterwards.
/// NOTE(review): the declaration of `segments` is on a line not visible in
/// this excerpt.
3153void acc::KernelsOp::addNumGangsOperands(
3154 MLIRContext *context, mlir::ValueRange newValues,
3155 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3157 if (getNumGangsSegmentsAttr())
3158 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3159
3160 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3161 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3162 getNumGangsMutable(), segments));
3163
3164 setNumGangsSegments(segments);
3165}
3166
3167void acc::KernelsOp::addWaitOnly(
3168 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3169 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3170 effectiveDeviceTypes));
3171}
/// Appends a wait clause with operands to this op.
/// The existing wait operand segments (if any) are copied into `segments`,
/// addDeviceTypeAffectedOperandHelper is called with `newValues`, the mutable
/// wait operand range and `segments`, and its result becomes the new
/// waitOperandsDeviceType attribute; `segments` is then stored back on the
/// op. Finally BoolAttr(`hasDevnum`) entries are appended to the existing
/// hasWaitDevnum array and the attribute is rebuilt.
/// NOTE(review): the declarations of `segments` and `hasDevnums` are on lines
/// not visible in this excerpt.
3172void acc::KernelsOp::addWaitOperands(
3173 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3174 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3175
3177 if (getWaitOperandsSegments())
3178 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3179
3180 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3181 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3182 getWaitOperandsMutable(), segments));
3183 setWaitOperandsSegments(segments);
3184
3186 if (getHasWaitDevnumAttr())
3187 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// One flag per effective device type, but never fewer than one: an empty
// `effectiveDeviceTypes` still appends a single entry.
3188 hasDevnums.insert(
3189 hasDevnums.end(),
3190 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3191 mlir::BoolAttr::get(context, hasDevnum));
3192 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3193}
3194
3195//===----------------------------------------------------------------------===//
3196// HostDataOp
3197//===----------------------------------------------------------------------===//
3198
/// Verifier for acc.host_data.
/// Fails when the op has no data clause operands, when an operand is not
/// defined by an acc::UseDeviceOp, or when two operands refer to the same
/// use_device variable (tracked through `seenVars`).
/// NOTE(review): the declaration of `seenVars` is on a line not visible in
/// this excerpt.
3199LogicalResult acc::HostDataOp::verify() {
3200 if (getDataClauseOperands().empty())
3201 return emitError("at least one operand must appear on the host_data "
3202 "operation");
3203
3205 for (mlir::Value operand : getDataClauseOperands()) {
// NOTE(review): getDefiningOp() returns null for values that are not op
// results (e.g. block arguments), and mlir::dyn_cast does not accept a null
// input. Consider operand.getDefiningOp<acc::UseDeviceOp>() or
// dyn_cast_or_null so such operands reach the diagnostic below.
3206 auto useDeviceOp =
3207 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3208 if (!useDeviceOp)
3209 return emitError("expect data entry operation as defining op");
3210
3211 // Check for duplicate use_device clauses
3212 if (!seenVars.insert(useDeviceOp.getVar()).second)
3213 return emitError("duplicate use_device variable");
3214 }
3215 return success();
3216}
3217
3218void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3219 MLIRContext *context) {
3220 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3221}
3222
3223//===----------------------------------------------------------------------===//
3224// LoopOp
3225//===----------------------------------------------------------------------===//
3226
/// Parses one optional `<keyword> = <operand> : <type>` gang value.
/// When `keyword` is not matched, the function returns success without
/// modifying any of its outputs. When it is matched, the operand and its type
/// are appended to `operands` / `types`, `gangArgType` is appended to
/// `attributes`, and both `needCommaBetweenValues` and `newValue` are set to
/// true. Failure is returned only when the keyword matched but the `=`, the
/// operand or the `: type` part could not be parsed.
/// NOTE(review): the declarations of the `operands` and `types` parameters
/// are on lines not visible in this excerpt.
3227static ParseResult parseGangValue(
3228 OpAsmParser &parser, llvm::StringRef keyword,
3231 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
3232 bool &needCommaBetweenValues, bool &newValue) {
3233 if (succeeded(parser.parseOptionalKeyword(keyword))) {
3234 if (parser.parseEqual())
3235 return failure();
3236 if (parser.parseOperand(operands.emplace_back()) ||
3237 parser.parseColonType(types.emplace_back()))
3238 return failure();
3239 attributes.push_back(gangArgType);
3240 needCommaBetweenValues = true;
3241 newValue = true;
3242 }
3243 return success();
3244}
3245
/// Custom-assembly parser for the gang clause.
/// Shapes handled by the visible code:
///  * no `(` at all: the clause is gang-only for DeviceType::None;
///  * `(` followed by an optional `[attr, ...]` list of gang-only device
///    types, then zero or more comma-separated groups of the form
///    `{ value (, value)* }` optionally followed by `[device_type]`, then `)`.
/// Each value is `<num|dim|static keyword> = operand : type` (see
/// parseGangValue). A group without an explicit device type is recorded with
/// DeviceType::None. For every group the number of operands it contributed
/// is pushed to `seg`, which becomes the `segments` attribute.
/// NOTE(review): three lines (the first operand parameter, the declaration of
/// `seg`, and the declaration of `gangOnlyAttr`) are not visible in this
/// excerpt.
3246static ParseResult parseGangClause(
3247 OpAsmParser &parser,
3249 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3250 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3251 mlir::ArrayAttr &gangOnlyDeviceType) {
3252 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3253 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3254 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3256 bool needCommaBetweenValues = false;
3257 bool needCommaBeforeOperands = false;
3258
3259 if (failed(parser.parseOptionalLParen())) {
3260 // Gang only keyword
3261 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3262 parser.getContext(), mlir::acc::DeviceType::None));
3263 gangOnlyDeviceType =
3264 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3265 return success();
3266 }
3267
3268 // Parse gang only attributes
3269 if (succeeded(parser.parseOptionalLSquare())) {
3270 // Parse gang only attributes
3271 if (failed(parser.parseCommaSeparatedList([&]() {
3272 if (parser.parseAttribute(
3273 gangOnlyDeviceTypeAttributes.emplace_back()))
3274 return failure();
3275 return success();
3276 })))
3277 return failure();
3278 if (parser.parseRSquare())
3279 return failure();
3280 needCommaBeforeOperands = true;
3281 }
3282
3283 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3284 mlir::acc::GangArgType::Num);
3285 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3286 mlir::acc::GangArgType::Dim);
3287 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3288 parser.getContext(), mlir::acc::GangArgType::Static);
3289
3290 do {
// After a gang-only `[...]` list the first pass parses no group: `continue`
// jumps to the loop condition, which consumes the separating comma if one is
// present and otherwise ends the loop.
3291 if (needCommaBeforeOperands) {
3292 needCommaBeforeOperands = false;
3293 continue;
3294 }
3295
3296 if (failed(parser.parseLBrace()))
3297 return failure();
3298
3299 int32_t crtOperandsSize = gangOperands.size();
// NOTE(review): `needCommaBetweenValues` is set by parseGangValue and is not
// reset anywhere in the visible code, so when a second `{...}` group starts
// the loop below first asks for a comma. Confirm that multi-group input
// parses as intended.
3300 while (true) {
3301 bool newValue = false;
3302 bool needValue = false;
3303 if (needCommaBetweenValues) {
3304 if (succeeded(parser.parseOptionalComma()))
3305 needValue = true; // expect a new value after comma.
3306 else
3307 break;
3308 }
3309
3310 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3311 gangOperands, gangOperandsType,
3312 gangArgTypeAttributes, argNum,
3313 needCommaBetweenValues, newValue)))
3314 return failure();
3315 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3316 gangOperands, gangOperandsType,
3317 gangArgTypeAttributes, argDim,
3318 needCommaBetweenValues, newValue)))
3319 return failure();
3320 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3321 gangOperands, gangOperandsType,
3322 gangArgTypeAttributes, argStatic,
3323 needCommaBetweenValues, newValue)))
3324 return failure();
3325
3326 if (!newValue && needValue) {
3327 parser.emitError(parser.getCurrentLocation(),
3328 "new value expected after comma");
3329 return failure();
3330 }
3331
3332 if (!newValue)
3333 break;
3334 }
3335
// NOTE(review): this tests the cumulative operand list, not the operands of
// the current group, so an empty later group is not rejected here.
3336 if (gangOperands.empty())
3337 return parser.emitError(
3338 parser.getCurrentLocation(),
3339 "expect at least one of num, dim or static values");
3340
3341 if (failed(parser.parseRBrace()))
3342 return failure();
3343
3344 if (succeeded(parser.parseOptionalLSquare())) {
3345 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3346 parser.parseRSquare())
3347 return failure();
3348 } else {
3349 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3350 parser.getContext(), mlir::acc::DeviceType::None));
3351 }
3352
// Record how many operands this group contributed.
3353 seg.push_back(gangOperands.size() - crtOperandsSize);
3354
3355 } while (succeeded(parser.parseOptionalComma()));
3356
3357 if (failed(parser.parseRParen()))
3358 return failure();
3359
3360 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3361 gangArgTypeAttributes.end());
3362 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3363 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3364
3366 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3367 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3368
3369 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3370 return success();
3371}
3372
3374 mlir::OperandRange operands, mlir::TypeRange types,
3375 std::optional<mlir::ArrayAttr> gangArgTypes,
3376 std::optional<mlir::ArrayAttr> deviceTypes,
3377 std::optional<mlir::DenseI32ArrayAttr> segments,
3378 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3379
3380 if (operands.begin() == operands.end() &&
3381 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3382 return;
3383 }
3384
3385 p << "(";
3386
3387 printDeviceTypes(p, gangOnlyDeviceTypes);
3388
3389 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3390 hasDeviceTypeValues(deviceTypes))
3391 p << ", ";
3392
3393 if (hasDeviceTypeValues(deviceTypes)) {
3394 unsigned opIdx = 0;
3395 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3396 p << "{";
3397 llvm::interleaveComma(
3398 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3399 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3400 (*gangArgTypes)[opIdx]);
3401 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3402 p << LoopOp::getGangNumKeyword();
3403 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3404 p << LoopOp::getGangDimKeyword();
3405 else if (gangArgTypeAttr.getValue() ==
3406 mlir::acc::GangArgType::Static)
3407 p << LoopOp::getGangStaticKeyword();
3408 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3409 ++opIdx;
3410 });
3411 p << "}";
3412 printSingleDeviceType(p, it.value());
3413 });
3414 }
3415 p << ")";
3416}
3417
3419 std::optional<mlir::ArrayAttr> segments,
3420 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3421 if (!segments)
3422 return false;
3423 for (auto attr : *segments) {
3424 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3425 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3426 return true;
3427 }
3428 return false;
3429}
3430
3431/// Check for duplicates in the DeviceType array attribute.
3432/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3433static std::optional<mlir::acc::DeviceType>
3434checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3435 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3436 if (!deviceTypes)
3437 return std::nullopt;
3438 for (auto attr : deviceTypes) {
3439 auto deviceTypeAttr =
3440 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3441 if (!deviceTypeAttr)
3442 return mlir::acc::DeviceType::None;
3443 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3444 return deviceTypeAttr.getValue();
3445 }
3446 return std::nullopt;
3447}
3448
/// Verifier for acc.loop.
/// The visible checks, in order:
///  * lowerbound / upperbound / step counts agree, and inclusiveUpperbound
///    (when present with bounds) has the same size as upperbound;
///  * collapse and its device_type attribute are consistent and free of
///    duplicate device types;
///  * gang, worker, vector and tile operands agree with their arg-type /
///    device_type attributes and contain no duplicate device types;
///  * auto / independent / seq do not share a device type, and at least one
///    of them is present for DeviceType::None;
///  * gang / worker / vector do not appear for a device type marked seq;
///  * private / firstprivate / reduction operands pass checkPrivateOperands;
///  * the combined attribute, if set, is one of the three known values;
///  * the region is non-empty, and for the container-like form enough nested
///    loop-like ops (per the largest collapse value) are found without
///    sibling loops.
/// NOTE(review): several lines (the openings of the gang and tile checks and
/// the early-exit statements inside the region walk) are not visible in this
/// excerpt.
3449LogicalResult acc::LoopOp::verify() {
3450 if (getUpperbound().size() != getStep().size())
3451 return emitError() << "number of upperbounds expected to be the same as "
3452 "number of steps";
3453
3454 if (getUpperbound().size() != getLowerbound().size())
3455 return emitError() << "number of upperbounds expected to be the same as "
3456 "number of lowerbounds";
3457
3458 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3459 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3460 return emitError() << "inclusiveUpperbound size is expected to be the same"
3461 << " as upperbound size";
3462
3463 // Check collapse
3464 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3465 return emitOpError() << "collapse device_type attr must be define when"
3466 << " collapse attr is present";
3467
3468 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3469 getCollapseAttr().getValue().size() !=
3470 getCollapseDeviceTypeAttr().getValue().size())
3471 return emitOpError() << "collapse attribute count must match collapse"
3472 << " device_type count";
3473 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3474 return emitOpError() << "duplicate device_type `"
3475 << acc::stringifyDeviceType(*duplicateDeviceType)
3476 << "` found in collapseDeviceType attribute";
3477
3478 // Check gang
3479 if (!getGangOperands().empty()) {
3480 if (!getGangOperandsArgType())
3481 return emitOpError() << "gangOperandsArgType attribute must be defined"
3482 << " when gang operands are present";
3483
3484 if (getGangOperands().size() !=
3485 getGangOperandsArgTypeAttr().getValue().size())
3486 return emitOpError() << "gangOperandsArgType attribute count must match"
3487 << " gangOperands count";
3488 }
3489 if (getGangAttr()) {
3490 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3491 return emitOpError() << "duplicate device_type `"
3492 << acc::stringifyDeviceType(*duplicateDeviceType)
3493 << "` found in gang attribute";
3494 }
3495
3497 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3498 getGangOperandsDeviceTypeAttr(), "gang")))
3499 return failure();
3500
3501 // Check worker
3502 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3503 return emitOpError() << "duplicate device_type `"
3504 << acc::stringifyDeviceType(*duplicateDeviceType)
3505 << "` found in worker attribute";
3506 if (auto duplicateDeviceType =
3507 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3508 return emitOpError() << "duplicate device_type `"
3509 << acc::stringifyDeviceType(*duplicateDeviceType)
3510 << "` found in workerNumOperandsDeviceType attribute";
3511 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3512 getWorkerNumOperandsDeviceTypeAttr(),
3513 "worker")))
3514 return failure();
3515
3516 // Check vector
3517 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3518 return emitOpError() << "duplicate device_type `"
3519 << acc::stringifyDeviceType(*duplicateDeviceType)
3520 << "` found in vector attribute";
3521 if (auto duplicateDeviceType =
3522 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3523 return emitOpError() << "duplicate device_type `"
3524 << acc::stringifyDeviceType(*duplicateDeviceType)
3525 << "` found in vectorOperandsDeviceType attribute";
3526 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3527 getVectorOperandsDeviceTypeAttr(),
3528 "vector")))
3529 return failure();
3530
3532 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3533 getTileOperandsDeviceTypeAttr(), "tile")))
3534 return failure();
3535
3536 // auto, independent and seq attribute are mutually exclusive.
// The same set is threaded through all three calls, so a device type that
// appears in two different attributes is also detected.
3537 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3538 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3539 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3540 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3541 return emitError() << "only one of auto, independent, seq can be present "
3542 "at the same time";
3543 }
3544
3545 // Check that at least one of auto, independent, or seq is present
3546 // for the device-independent default clauses.
3547 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3548 return attr.getValue() == mlir::acc::DeviceType::None;
3549 };
3550 bool hasDefaultSeq =
3551 getSeqAttr()
3552 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3553 hasDeviceNone)
3554 : false;
3555 bool hasDefaultIndependent =
3556 getIndependentAttr()
3557 ? llvm::any_of(
3558 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3559 hasDeviceNone)
3560 : false;
3561 bool hasDefaultAuto =
3562 getAuto_Attr()
3563 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3564 hasDeviceNone)
3565 : false;
3566 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3567 return emitError()
3568 << "at least one of auto, independent, seq must be present";
3569 }
3570
3571 // Gang, worker and vector are incompatible with seq.
3572 if (getSeqAttr()) {
3573 for (auto attr : getSeqAttr()) {
// NOTE(review): the dyn_cast result is used without a null check; an entry
// that is not a DeviceTypeAttr would be dereferenced here.
3574 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3575 if (hasVector(deviceTypeAttr.getValue()) ||
3576 getVectorValue(deviceTypeAttr.getValue()) ||
3577 hasWorker(deviceTypeAttr.getValue()) ||
3578 getWorkerValue(deviceTypeAttr.getValue()) ||
3579 hasGang(deviceTypeAttr.getValue()) ||
3580 getGangValue(mlir::acc::GangArgType::Num,
3581 deviceTypeAttr.getValue()) ||
3582 getGangValue(mlir::acc::GangArgType::Dim,
3583 deviceTypeAttr.getValue()) ||
3584 getGangValue(mlir::acc::GangArgType::Static,
3585 deviceTypeAttr.getValue()))
3586 return emitError() << "gang, worker or vector cannot appear with seq";
3587 }
3588 }
3589
3590 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3591 mlir::acc::PrivateRecipeOp>(
3592 *this, getPrivateOperands(), "private")))
3593 return failure();
3594
3595 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3596 mlir::acc::FirstprivateRecipeOp>(
3597 *this, getFirstprivateOperands(), "firstprivate")))
3598 return failure();
3599
3600 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3601 mlir::acc::ReductionRecipeOp>(
3602 *this, getReductionOperands(), "reduction")))
3603 return failure();
3604
3605 if (getCombined().has_value() &&
3606 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3607 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3608 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3609 return emitError("unexpected combined constructs attribute");
3610 }
3611
3612 // Check non-empty body().
3613 if (getRegion().empty())
3614 return emitError("expected non-empty body.");
3615
3616 if (getUnstructured()) {
3617 if (!isContainerLike())
3618 return emitError(
3619 "unstructured acc.loop must not have induction variables");
3620 } else if (isContainerLike()) {
3621 // When it is container-like - it is expected to hold a loop-like operation.
3622 // Obtain the maximum collapse count - we use this to check that there
3623 // are enough loops contained.
3624 uint64_t collapseCount = getCollapseValue().value_or(1);
3625 if (getCollapseAttr()) {
3626 for (auto collapseEntry : getCollapseAttr()) {
// NOTE(review): the dyn_cast result is used without a null check.
3627 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3628 if (intAttr.getValue().getZExtValue() > collapseCount)
3629 collapseCount = intAttr.getValue().getZExtValue();
3630 }
3631 }
3632
3633 // We want to check that we find enough loop-like operations inside.
3634 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3635 // level.
3636 mlir::Operation *expectedParent = this->getOperation();
3637 bool foundSibling = false;
3638 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3639 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3640 // This effectively checks that we are not looking at a sibling loop.
3641 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3642 expectedParent) {
3643 foundSibling = true;
3645 }
3646
3647 collapseCount--;
3648 expectedParent = op;
3649 }
3650 // We found enough contained loops.
3651 if (collapseCount == 0)
3654 });
3655
3656 if (foundSibling)
3657 return emitError("found sibling loops inside container-like acc.loop");
3658 if (collapseCount != 0)
3659 return emitError("failed to find enough loop-like operations inside "
3660 "container-like acc.loop");
3661 }
3662
3663 return success();
3664}
3665
3666unsigned LoopOp::getNumDataOperands() {
3667 return getReductionOperands().size() + getPrivateOperands().size() +
3668 getFirstprivateOperands().size();
3669}
3670
3671Value LoopOp::getDataOperand(unsigned i) {
3672 unsigned numOptional =
3673 getLowerbound().size() + getUpperbound().size() + getStep().size();
3674 numOptional += getGangOperands().size();
3675 numOptional += getVectorOperands().size();
3676 numOptional += getWorkerNumOperands().size();
3677 numOptional += getTileOperands().size();
3678 numOptional += getCacheOperands().size();
3679 return getOperand(numOptional + i);
3680}
3681
3682bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3683
3684bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3685 return hasDeviceType(getAuto_(), deviceType);
3686}
3687
3688bool LoopOp::hasIndependent() {
3689 return hasIndependent(mlir::acc::DeviceType::None);
3690}
3691
3692bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3693 return hasDeviceType(getIndependent(), deviceType);
3694}
3695
3696bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3697
3698bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3699 return hasDeviceType(getSeq(), deviceType);
3700}
3701
3702mlir::Value LoopOp::getVectorValue() {
3703 return getVectorValue(mlir::acc::DeviceType::None);
3704}
3705
3706mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3707 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3708 getVectorOperands(), deviceType);
3709}
3710
3711bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3712
3713bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3714 return hasDeviceType(getVector(), deviceType);
3715}
3716
3717mlir::Value LoopOp::getWorkerValue() {
3718 return getWorkerValue(mlir::acc::DeviceType::None);
3719}
3720
3721mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3722 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3723 getWorkerNumOperands(), deviceType);
3724}
3725
3726bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3727
3728bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3729 return hasDeviceType(getWorker(), deviceType);
3730}
3731
3732mlir::Operation::operand_range LoopOp::getTileValues() {
3733 return getTileValues(mlir::acc::DeviceType::None);
3734}
3735
3737LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3738 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3739 getTileOperandsSegments(), deviceType);
3740}
3741
3742std::optional<int64_t> LoopOp::getCollapseValue() {
3743 return getCollapseValue(mlir::acc::DeviceType::None);
3744}
3745
3746std::optional<int64_t>
3747LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3748 if (!getCollapseAttr())
3749 return std::nullopt;
3750 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3751 auto intAttr =
3752 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3753 return intAttr.getValue().getZExtValue();
3754 }
3755 return std::nullopt;
3756}
3757
3758mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3759 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3760}
3761
3762mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3763 mlir::acc::DeviceType deviceType) {
3764 if (getGangOperands().empty())
3765 return {};
3766 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3767 int32_t nbOperandsBefore = 0;
3768 for (unsigned i = 0; i < *pos; ++i)
3769 nbOperandsBefore += (*getGangOperandsSegments())[i];
3771 getGangOperands()
3772 .drop_front(nbOperandsBefore)
3773 .take_front((*getGangOperandsSegments())[*pos]);
3774
3775 int32_t argTypeIdx = nbOperandsBefore;
3776 for (auto value : values) {
3777 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3778 (*getGangOperandsArgType())[argTypeIdx]);
3779 if (gangArgTypeAttr.getValue() == gangArgType)
3780 return value;
3781 ++argTypeIdx;
3782 }
3783 }
3784 return {};
3785}
3786
3787bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3788
3789bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3790 return hasDeviceType(getGang(), deviceType);
3791}
3792
3793llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3794 return {&getRegion()};
3795}
3796
3797/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3798/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3799/// `(` ssa-id-and-type-list `)`
3800/// region
3801ParseResult
3804 SmallVectorImpl<Type> &lowerboundType,
3806 SmallVectorImpl<Type> &upperboundType,
3808 SmallVectorImpl<Type> &stepType) {
3809
3811 if (succeeded(
3812 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3813 if (parser.parseLParen() ||
3814 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3815 /*allowType=*/true) ||
3816 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3817 parser.parseOperandList(lowerbound, inductionVars.size(),
3819 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3820 parser.parseKeyword("to") || parser.parseLParen() ||
3821 parser.parseOperandList(upperbound, inductionVars.size(),
3823 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3824 parser.parseKeyword("step") || parser.parseLParen() ||
3825 parser.parseOperandList(step, inductionVars.size(),
3827 parser.parseColonTypeList(stepType) || parser.parseRParen())
3828 return failure();
3829 }
3830 return parser.parseRegion(region, inductionVars);
3831}
3832
3834 ValueRange lowerbound, TypeRange lowerboundType,
3835 ValueRange upperbound, TypeRange upperboundType,
3836 ValueRange steps, TypeRange stepType) {
3837 ValueRange regionArgs = region.front().getArguments();
3838 if (!regionArgs.empty()) {
3839 p << acc::LoopOp::getControlKeyword() << "(";
3840 llvm::interleaveComma(regionArgs, p,
3841 [&p](Value v) { p << v << " : " << v.getType(); });
3842 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3843 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3844 << " : " << stepType << ") ";
3845 }
3846 p.printRegion(region, /*printEntryBlockArgs=*/false);
3847}
3848
3849void acc::LoopOp::addSeq(MLIRContext *context,
3850 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3851 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3852 effectiveDeviceTypes));
3853}
3854
3855void acc::LoopOp::addIndependent(
3856 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3857 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3858 context, getIndependentAttr(), effectiveDeviceTypes));
3859}
3860
3861void acc::LoopOp::addAuto(MLIRContext *context,
3862 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3863 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3864 effectiveDeviceTypes));
3865}
3866
3867void acc::LoopOp::setCollapseForDeviceTypes(
3868 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3869 llvm::APInt value) {
3872
3873 assert((getCollapseAttr() == nullptr) ==
3874 (getCollapseDeviceTypeAttr() == nullptr));
3875 assert(value.getBitWidth() == 64);
3876
3877 if (getCollapseAttr()) {
3878 for (const auto &existing :
3879 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3880 newValues.push_back(std::get<0>(existing));
3881 newDeviceTypes.push_back(std::get<1>(existing));
3882 }
3883 }
3884
3885 if (effectiveDeviceTypes.empty()) {
3886 // If the effective device-types list is empty, this is before there are any
3887 // being applied by device_type, so this should be added as a 'none'.
3888 newValues.push_back(
3889 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3890 newDeviceTypes.push_back(
3891 acc::DeviceTypeAttr::get(context, DeviceType::None));
3892 } else {
3893 for (DeviceType dt : effectiveDeviceTypes) {
3894 newValues.push_back(
3895 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3896 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3897 }
3898 }
3899
3900 setCollapseAttr(ArrayAttr::get(context, newValues));
3901 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3902}
3903
3904void acc::LoopOp::setTileForDeviceTypes(
3905 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3906 ValueRange values) {
3908 if (getTileOperandsSegments())
3909 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3910
3911 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3912 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3913 getTileOperandsMutable(), segments));
3914
3915 setTileOperandsSegments(segments);
3916}
3917
3918void acc::LoopOp::addVectorOperand(
3919 MLIRContext *context, mlir::Value newValue,
3920 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3921 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3922 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3923 newValue, getVectorOperandsMutable()));
3924}
3925
3926void acc::LoopOp::addEmptyVector(
3927 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3928 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3929 effectiveDeviceTypes));
3930}
3931
3932void acc::LoopOp::addWorkerNumOperand(
3933 MLIRContext *context, mlir::Value newValue,
3934 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3935 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3936 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3937 newValue, getWorkerNumOperandsMutable()));
3938}
3939
3940void acc::LoopOp::addEmptyWorker(
3941 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3942 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3943 effectiveDeviceTypes));
3944}
3945
3946void acc::LoopOp::addEmptyGang(
3947 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3948 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3949 effectiveDeviceTypes));
3950}
3951
3952bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3953 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3954 return attr.getValue() == dt;
3955 };
3956 auto testFromArr = [=](ArrayAttr arr) -> bool {
3957 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3958 };
3959
3960 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3961 return true;
3962 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3963 return true;
3964 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3965 return true;
3966
3967 return false;
3968}
3969
3970bool acc::LoopOp::hasDefaultGangWorkerVector() {
3971 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3972 hasGang() || getGangValue(GangArgType::Num) ||
3973 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3974}
3975
3976acc::LoopParMode
3977acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3978 if (hasSeq(deviceType))
3979 return LoopParMode::loop_seq;
3980 if (hasAuto(deviceType))
3981 return LoopParMode::loop_auto;
3982 if (hasIndependent(deviceType))
3983 return LoopParMode::loop_independent;
3984 if (hasSeq())
3985 return LoopParMode::loop_seq;
3986 if (hasAuto())
3987 return LoopParMode::loop_auto;
3988 assert(hasIndependent() &&
3989 "loop must have default auto, seq, or independent");
3990 return LoopParMode::loop_independent;
3991}
3992
3993void acc::LoopOp::addGangOperands(
3994 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3997 if (std::optional<ArrayRef<int32_t>> existingSegments =
3998 getGangOperandsSegments())
3999 llvm::copy(*existingSegments, std::back_inserter(segments));
4000
4001 unsigned beforeCount = segments.size();
4002
4003 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4004 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
4005 getGangOperandsMutable(), segments));
4006
4007 setGangOperandsSegments(segments);
4008
4009 // This is a bit of extra work to make sure we update the 'types' correctly by
4010 // adding to the types collection the correct number of times. We could
4011 // potentially add something similar to the
4012 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
4013 // excessive for a one-off case.
4014 unsigned numAdded = segments.size() - beforeCount;
4015
4016 if (numAdded > 0) {
4018 if (getGangOperandsArgTypeAttr())
4019 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
4020
4021 for (auto i : llvm::index_range(0u, numAdded)) {
4022 llvm::transform(argTypes, std::back_inserter(gangTypes),
4023 [=](mlir::acc::GangArgType gangTy) {
4024 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
4025 });
4026 (void)i;
4027 }
4028
4029 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
4030 }
4031}
4032
4033void acc::LoopOp::addPrivatization(MLIRContext *context,
4034 mlir::acc::PrivateOp op,
4035 mlir::acc::PrivateRecipeOp recipe) {
4036 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4037 getPrivateOperandsMutable().append(op.getResult());
4038}
4039
4040void acc::LoopOp::addFirstPrivatization(
4041 MLIRContext *context, mlir::acc::FirstprivateOp op,
4042 mlir::acc::FirstprivateRecipeOp recipe) {
4043 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4044 getFirstprivateOperandsMutable().append(op.getResult());
4045}
4046
4047void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
4048 mlir::acc::ReductionRecipeOp recipe) {
4049 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4050 getReductionOperandsMutable().append(op.getResult());
4051}
4052
4053//===----------------------------------------------------------------------===//
4054// DataOp
4055//===----------------------------------------------------------------------===//
4056
// Verifies acc.data: the op must carry at least one operand or a default
// attribute, and every data clause operand must be the result of one of the
// data entry/exit operations (or acc.getdeviceptr) listed below. Block
// arguments are rejected as data clause operands.
4057LogicalResult acc::DataOp::verify() {
4058 // 2.6.5. Data Construct restriction
4059 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
4060 // attach, or default clause must appear on a data construct.
4061 if (getOperands().empty() && !getDefaultAttr())
4062 return emitError("at least one operand or the default attribute "
4063 "must appear on the data operation");
4064
// The BlockArgument test comes first so that getDefiningOp() is only
// inspected for values that are not block arguments.
4065 for (mlir::Value operand : getDataClauseOperands())
4066 if (isa<BlockArgument>(operand) ||
4067 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4068 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
4069 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
4070 operand.getDefiningOp()))
4071 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4072 "as defining op");
4073
// NOTE(review): the condition guarding the following `return failure();` is
// not present in this view of the file — confirm it against the full source.
4075 return failure();
4076
4077 return success();
4078}
4079
4080unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
4081
4082Value DataOp::getDataOperand(unsigned i) {
4083 unsigned numOptional = getIfCond() ? 1 : 0;
4084 numOptional += getAsyncOperands().size() ? 1 : 0;
4085 numOptional += getWaitOperands().size();
4086 return getOperand(numOptional + i);
4087}
4088
4089bool acc::DataOp::hasAsyncOnly() {
4090 return hasAsyncOnly(mlir::acc::DeviceType::None);
4091}
4092
4093bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4094 return hasDeviceType(getAsyncOnly(), deviceType);
4095}
4096
4097mlir::Value DataOp::getAsyncValue() {
4098 return getAsyncValue(mlir::acc::DeviceType::None);
4099}
4100
4101mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4103 getAsyncOperands(), deviceType);
4104}
4105
4106bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4107
4108bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4109 return hasDeviceType(getWaitOnly(), deviceType);
4110}
4111
4112mlir::Operation::operand_range DataOp::getWaitValues() {
4113 return getWaitValues(mlir::acc::DeviceType::None);
4114}
4115
// Device-type overload of getWaitValues. The visible arguments forward the
// wait operands, their device-type attribute, their segment sizes, the
// has-wait-devnum flags and `deviceType` to a callee.
// NOTE(review): the return-type line and the `return <callee>(` line of this
// definition are not present in this view — confirm both against the full
// source.
4117DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4119 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4120 getHasWaitDevnum(), deviceType);
4121}
4122
4123mlir::Value DataOp::getWaitDevnum() {
4124 return getWaitDevnum(mlir::acc::DeviceType::None);
4125}
4126
4127mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4128 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4129 getWaitOperandsSegments(), getHasWaitDevnum(),
4130 deviceType);
4131}
4132
4133void acc::DataOp::addAsyncOnly(
4134 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4135 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4136 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4137}
4138
4139void acc::DataOp::addAsyncOperand(
4140 MLIRContext *context, mlir::Value newValue,
4141 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4142 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4143 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4144 getAsyncOperandsMutable()));
4145}
4146
4147void acc::DataOp::addWaitOnly(MLIRContext *context,
4148 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4149 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4150 effectiveDeviceTypes));
4151}
4152
4153void acc::DataOp::addWaitOperands(
4154 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4155 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4156
4158 if (getWaitOperandsSegments())
4159 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4160
4161 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4162 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4163 getWaitOperandsMutable(), segments));
4164 setWaitOperandsSegments(segments);
4165
4167 if (getHasWaitDevnumAttr())
4168 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4169 hasDevnums.insert(
4170 hasDevnums.end(),
4171 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4172 mlir::BoolAttr::get(context, hasDevnum));
4173 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4174}
4175
4176//===----------------------------------------------------------------------===//
4177// ExitDataOp
4178//===----------------------------------------------------------------------===//
4179
4180LogicalResult acc::ExitDataOp::verify() {
4181 // 2.6.6. Data Exit Directive restriction
4182 // At least one copyout, delete, or detach clause must appear on an exit data
4183 // directive.
4184 if (getDataClauseOperands().empty())
4185 return emitError("at least one operand must be present in dataOperands on "
4186 "the exit data operation");
4187
4188 // The async attribute represent the async clause without value. Therefore the
4189 // attribute and operand cannot appear at the same time.
4190 if (getAsyncOperand() && getAsync())
4191 return emitError("async attribute cannot appear with asyncOperand");
4192
4193 // The wait attribute represent the wait clause without values. Therefore the
4194 // attribute and operands cannot appear at the same time.
4195 if (!getWaitOperands().empty() && getWait())
4196 return emitError("wait attribute cannot appear with waitOperands");
4197
4198 if (getWaitDevnum() && getWaitOperands().empty())
4199 return emitError("wait_devnum cannot appear without waitOperands");
4200
4201 return success();
4202}
4203
4204unsigned ExitDataOp::getNumDataOperands() {
4205 return getDataClauseOperands().size();
4206}
4207
4208Value ExitDataOp::getDataOperand(unsigned i) {
4209 unsigned numOptional = getIfCond() ? 1 : 0;
4210 numOptional += getAsyncOperand() ? 1 : 0;
4211 numOptional += getWaitDevnum() ? 1 : 0;
4212 return getOperand(getWaitOperands().size() + numOptional + i);
4213}
4214
4215void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4216 MLIRContext *context) {
4217 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4218}
4219
4220void ExitDataOp::addAsyncOnly(MLIRContext *context,
4221 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4222 assert(effectiveDeviceTypes.empty());
4223 assert(!getAsyncAttr());
4224 assert(!getAsyncOperand());
4225
4226 setAsyncAttr(mlir::UnitAttr::get(context));
4227}
4228
4229void ExitDataOp::addAsyncOperand(
4230 MLIRContext *context, mlir::Value newValue,
4231 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4232 assert(effectiveDeviceTypes.empty());
4233 assert(!getAsyncAttr());
4234 assert(!getAsyncOperand());
4235
4236 getAsyncOperandMutable().append(newValue);
4237}
4238
4239void ExitDataOp::addWaitOnly(MLIRContext *context,
4240 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4241 assert(effectiveDeviceTypes.empty());
4242 assert(!getWaitAttr());
4243 assert(getWaitOperands().empty());
4244 assert(!getWaitDevnum());
4245
4246 setWaitAttr(mlir::UnitAttr::get(context));
4247}
4248
4249void ExitDataOp::addWaitOperands(
4250 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4251 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4252 assert(effectiveDeviceTypes.empty());
4253 assert(!getWaitAttr());
4254 assert(getWaitOperands().empty());
4255 assert(!getWaitDevnum());
4256
4257 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4258 // operands list.
4259 if (hasDevnum) {
4260 getWaitDevnumMutable().append(newValues.front());
4261 newValues = newValues.drop_front();
4262 }
4263
4264 getWaitOperandsMutable().append(newValues);
4265}
4266
4267//===----------------------------------------------------------------------===//
4268// EnterDataOp
4269//===----------------------------------------------------------------------===//
4270
4271LogicalResult acc::EnterDataOp::verify() {
4272 // 2.6.6. Data Enter Directive restriction
4273 // At least one copyin, create, or attach clause must appear on an enter data
4274 // directive.
4275 if (getDataClauseOperands().empty())
4276 return emitError("at least one operand must be present in dataOperands on "
4277 "the enter data operation");
4278
4279 // The async attribute represent the async clause without value. Therefore the
4280 // attribute and operand cannot appear at the same time.
4281 if (getAsyncOperand() && getAsync())
4282 return emitError("async attribute cannot appear with asyncOperand");
4283
4284 // The wait attribute represent the wait clause without values. Therefore the
4285 // attribute and operands cannot appear at the same time.
4286 if (!getWaitOperands().empty() && getWait())
4287 return emitError("wait attribute cannot appear with waitOperands");
4288
4289 if (getWaitDevnum() && getWaitOperands().empty())
4290 return emitError("wait_devnum cannot appear without waitOperands");
4291
4292 for (mlir::Value operand : getDataClauseOperands())
4293 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4294 operand.getDefiningOp()))
4295 return emitError("expect data entry operation as defining op");
4296
4297 return success();
4298}
4299
4300unsigned EnterDataOp::getNumDataOperands() {
4301 return getDataClauseOperands().size();
4302}
4303
4304Value EnterDataOp::getDataOperand(unsigned i) {
4305 unsigned numOptional = getIfCond() ? 1 : 0;
4306 numOptional += getAsyncOperand() ? 1 : 0;
4307 numOptional += getWaitDevnum() ? 1 : 0;
4308 return getOperand(getWaitOperands().size() + numOptional + i);
4309}
4310
4311void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4312 MLIRContext *context) {
4313 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4314}
4315
4316void EnterDataOp::addAsyncOnly(
4317 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4318 assert(effectiveDeviceTypes.empty());
4319 assert(!getAsyncAttr());
4320 assert(!getAsyncOperand());
4321
4322 setAsyncAttr(mlir::UnitAttr::get(context));
4323}
4324
4325void EnterDataOp::addAsyncOperand(
4326 MLIRContext *context, mlir::Value newValue,
4327 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4328 assert(effectiveDeviceTypes.empty());
4329 assert(!getAsyncAttr());
4330 assert(!getAsyncOperand());
4331
4332 getAsyncOperandMutable().append(newValue);
4333}
4334
4335void EnterDataOp::addWaitOnly(MLIRContext *context,
4336 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4337 assert(effectiveDeviceTypes.empty());
4338 assert(!getWaitAttr());
4339 assert(getWaitOperands().empty());
4340 assert(!getWaitDevnum());
4341
4342 setWaitAttr(mlir::UnitAttr::get(context));
4343}
4344
4345void EnterDataOp::addWaitOperands(
4346 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4347 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4348 assert(effectiveDeviceTypes.empty());
4349 assert(!getWaitAttr());
4350 assert(getWaitOperands().empty());
4351 assert(!getWaitDevnum());
4352
4353 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4354 // operands list.
4355 if (hasDevnum) {
4356 getWaitDevnumMutable().append(newValues.front());
4357 newValues = newValues.drop_front();
4358 }
4359
4360 getWaitOperandsMutable().append(newValues);
4361}
4362
4363//===----------------------------------------------------------------------===//
4364// AtomicReadOp
4365//===----------------------------------------------------------------------===//
4366
4367LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4368
4369//===----------------------------------------------------------------------===//
4370// AtomicWriteOp
4371//===----------------------------------------------------------------------===//
4372
4373LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4374
4375//===----------------------------------------------------------------------===//
4376// AtomicUpdateOp
4377//===----------------------------------------------------------------------===//
4378
4379LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4380 PatternRewriter &rewriter) {
4381 if (op.isNoOp()) {
4382 rewriter.eraseOp(op);
4383 return success();
4384 }
4385
4386 if (Value writeVal = op.getWriteOpVal()) {
4387 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4388 op.getIfCond());
4389 return success();
4390 }
4391
4392 return failure();
4393}
4394
4395LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4396
4397LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4398
4399//===----------------------------------------------------------------------===//
4400// AtomicCaptureOp
4401//===----------------------------------------------------------------------===//
4402
4403AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4404 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4405 return op;
4406 return dyn_cast<AtomicReadOp>(getSecondOp());
4407}
4408
4409AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4410 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4411 return op;
4412 return dyn_cast<AtomicWriteOp>(getSecondOp());
4413}
4414
4415AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4416 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4417 return op;
4418 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4419}
4420
4421LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4422
4423//===----------------------------------------------------------------------===//
4424// DeclareEnterOp
4425//===----------------------------------------------------------------------===//
4426
4427template <typename Op>
4428static LogicalResult
4430 bool requireAtLeastOneOperand = true) {
4431 if (operands.empty() && requireAtLeastOneOperand)
4432 return emitError(
4433 op->getLoc(),
4434 "at least one operand must appear on the declare operation");
4435
4436 for (mlir::Value operand : operands) {
4437 if (isa<BlockArgument>(operand) ||
4438 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4439 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4440 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4441 operand.getDefiningOp()))
4442 return op.emitError(
4443 "expect valid declare data entry operation or acc.getdeviceptr "
4444 "as defining op");
4445
4446 mlir::Value var{getVar(operand.getDefiningOp())};
4447 assert(var && "declare operands can only be data entry operations which "
4448 "must have var");
4449 (void)var;
4450 std::optional<mlir::acc::DataClause> dataClauseOptional{
4451 getDataClause(operand.getDefiningOp())};
4452 assert(dataClauseOptional.has_value() &&
4453 "declare operands can only be data entry operations which must have "
4454 "dataClause");
4455 (void)dataClauseOptional;
4456 }
4457
4458 return success();
4459}
4460
4461LogicalResult acc::DeclareEnterOp::verify() {
4462 return checkDeclareOperands(*this, this->getDataClauseOperands());
4463}
4464
4465//===----------------------------------------------------------------------===//
4466// DeclareExitOp
4467//===----------------------------------------------------------------------===//
4468
4469LogicalResult acc::DeclareExitOp::verify() {
4470 if (getToken())
4471 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4472 /*requireAtLeastOneOperand=*/false);
4473 return checkDeclareOperands(*this, this->getDataClauseOperands());
4474}
4475
4476//===----------------------------------------------------------------------===//
4477// DeclareOp
4478//===----------------------------------------------------------------------===//
4479
4480LogicalResult acc::DeclareOp::verify() {
4481 return checkDeclareOperands(*this, this->getDataClauseOperands());
4482}
4483
4484//===----------------------------------------------------------------------===//
4485// RoutineOp
4486//===----------------------------------------------------------------------===//
4487
4488static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4489 acc::DeviceType dtype) {
4490 unsigned parallelism = 0;
4491 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4492 parallelism += op.hasWorker(dtype) ? 1 : 0;
4493 parallelism += op.hasVector(dtype) ? 1 : 0;
4494 parallelism += op.hasSeq(dtype) ? 1 : 0;
4495 return parallelism;
4496}
4497
4498LogicalResult acc::RoutineOp::verify() {
4499 unsigned baseParallelism =
4500 getParallelismForDeviceType(*this, acc::DeviceType::None);
4501
4502 if (baseParallelism > 1)
4503 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4504 "be present at the same time";
4505
4506 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4507 ++dtypeInt) {
4508 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4509 if (dtype == acc::DeviceType::None)
4510 continue;
4511 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4512
4513 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4514 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4515 "be present at the same time for device_type `"
4516 << acc::stringifyDeviceType(dtype) << "`";
4517 }
4518
4519 return success();
4520}
4521
4522static ParseResult parseBindName(OpAsmParser &parser,
4523 mlir::ArrayAttr &bindIdName,
4524 mlir::ArrayAttr &bindStrName,
4525 mlir::ArrayAttr &deviceIdTypes,
4526 mlir::ArrayAttr &deviceStrTypes) {
4527 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4528 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4529 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4530 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4531
4532 if (failed(parser.parseCommaSeparatedList([&]() {
4533 mlir::Attribute newAttr;
4534 bool isSymbolRefAttr;
4535 auto parseResult = parser.parseAttribute(newAttr);
4536 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4537 bindIdNameAttrs.push_back(symbolRefAttr);
4538 isSymbolRefAttr = true;
4539 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4540 bindStrNameAttrs.push_back(stringAttr);
4541 isSymbolRefAttr = false;
4542 }
4543 if (parseResult)
4544 return failure();
4545 if (failed(parser.parseOptionalLSquare())) {
4546 if (isSymbolRefAttr) {
4547 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4548 parser.getContext(), mlir::acc::DeviceType::None));
4549 } else {
4550 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4551 parser.getContext(), mlir::acc::DeviceType::None));
4552 }
4553 } else {
4554 if (isSymbolRefAttr) {
4555 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4556 parser.parseRSquare())
4557 return failure();
4558 } else {
4559 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4560 parser.parseRSquare())
4561 return failure();
4562 }
4563 }
4564 return success();
4565 })))
4566 return failure();
4567
4568 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4569 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4570 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4571 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4572
4573 return success();
4574}
4575
// Printer for a bind-name list: the symbol-reference names are emitted first,
// then the string names, as one comma-separated sequence. Each name is
// followed by its device type via printSingleDeviceType.
// NOTE(review): the first signature line and the declarations of
// `allBindNames` / `allDeviceTypes` are not present in this extract; confirm
// them against the full file.
4577 std::optional<mlir::ArrayAttr> bindIdName,
4578 std::optional<mlir::ArrayAttr> bindStrName,
4579 std::optional<mlir::ArrayAttr> deviceIdTypes,
4580 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4581 // Create combined vectors for all bind names and device types
4584
4585 // Append bindIdName and deviceIdTypes
// Only `deviceIdTypes` is tested before `bindIdName` is dereferenced;
// presumably both are always set together. TODO confirm.
4586 if (hasDeviceTypeValues(deviceIdTypes)) {
4587 allBindNames.append(bindIdName->begin(), bindIdName->end());
4588 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4589 }
4590
4591 // Append bindStrName and deviceStrTypes
4592 if (hasDeviceTypeValues(deviceStrTypes)) {
4593 allBindNames.append(bindStrName->begin(), bindStrName->end());
4594 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4595 }
4596
4597 // Print the combined sequence
4598 if (!allBindNames.empty())
4599 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4600 [&](const auto &pair) {
4601 p << std::get<0>(pair);
4602 printSingleDeviceType(p, std::get<1>(pair));
4603 });
4604}
4605
4606static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4607 mlir::ArrayAttr &gang,
4608 mlir::ArrayAttr &gangDim,
4609 mlir::ArrayAttr &gangDimDeviceTypes) {
4610
4611 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4612 gangDimDeviceTypeAttrs;
4613 bool needCommaBeforeOperands = false;
4614
4615 // Gang keyword only
4616 if (failed(parser.parseOptionalLParen())) {
4617 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4618 parser.getContext(), mlir::acc::DeviceType::None));
4619 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4620 return success();
4621 }
4622
4623 // Parse keyword only attributes
4624 if (succeeded(parser.parseOptionalLSquare())) {
4625 if (failed(parser.parseCommaSeparatedList([&]() {
4626 if (parser.parseAttribute(gangAttrs.emplace_back()))
4627 return failure();
4628 return success();
4629 })))
4630 return failure();
4631 if (parser.parseRSquare())
4632 return failure();
4633 needCommaBeforeOperands = true;
4634 }
4635
4636 if (needCommaBeforeOperands && failed(parser.parseComma()))
4637 return failure();
4638
4639 if (failed(parser.parseCommaSeparatedList([&]() {
4640 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4641 parser.parseColon() ||
4642 parser.parseAttribute(gangDimAttrs.emplace_back()))
4643 return failure();
4644 if (succeeded(parser.parseOptionalLSquare())) {
4645 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4646 parser.parseRSquare())
4647 return failure();
4648 } else {
4649 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4650 parser.getContext(), mlir::acc::DeviceType::None));
4651 }
4652 return success();
4653 })))
4654 return failure();
4655
4656 if (failed(parser.parseRParen()))
4657 return failure();
4658
4659 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4660 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4661 gangDimDeviceTypes =
4662 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4663
4664 return success();
4665}
4666
// Printer for the parenthesized part of a routine gang clause. Prints nothing
// when the only information is a single `gang` entry of DeviceType::None and
// there are no gang-dim device types; otherwise prints
// `(` gang device types, then `<gang-dim-keyword>: value` entries `)`.
// NOTE(review): the first signature line is not present in this extract.
4668 std::optional<mlir::ArrayAttr> gang,
4669 std::optional<mlir::ArrayAttr> gangDim,
4670 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4671
// Keyword-only form: nothing to print after the keyword.
4672 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4673 gang->size() == 1) {
// NOTE(review): the dyn_cast result is used without a null check; this
// relies on every element of `gang` being a DeviceTypeAttr.
4674 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4675 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4676 return;
4677 }
4678
4679 p << "(";
4680
4681 printDeviceTypes(p, gang);
4682
// Separator only when both the device-type list and the dim list are printed.
4683 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4684 p << ", ";
4685
// `gangDim` is dereferenced based on the check of `gangDimDeviceTypes` alone;
// presumably the two arrays are always set together. TODO confirm.
4686 if (hasDeviceTypeValues(gangDimDeviceTypes))
4687 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4688 [&](const auto &pair) {
4689 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4690 p << std::get<0>(pair);
4691 printSingleDeviceType(p, std::get<1>(pair));
4692 });
4693
4694 p << ")";
4695}
4696
// Parses an optional `(` `[` attr (`,` attr)* `]` `)` list into `deviceTypes`.
// Without an opening paren the result is a single DeviceType::None entry.
// NOTE(review): the declaration of `attributes` is not present in this
// extract; confirm it against the full file.
4697static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4698 mlir::ArrayAttr &deviceTypes) {
4700 // Keyword only
4701 if (failed(parser.parseOptionalLParen())) {
4702 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4703 parser.getContext(), mlir::acc::DeviceType::None));
4704 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4705 return success();
4706 }
4707
4708 // Parse device type attributes
// NOTE(review): if `(` was consumed but no `[` follows, nothing below consumes
// a closing `)` and an empty array is returned; verify this is intended.
4709 if (succeeded(parser.parseOptionalLSquare())) {
4710 if (failed(parser.parseCommaSeparatedList([&]() {
4711 if (parser.parseAttribute(attributes.emplace_back()))
4712 return failure();
4713 return success();
4714 })))
4715 return failure();
4716 if (parser.parseRSquare() || parser.parseRParen())
4717 return failure();
4718 }
4719 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4720 return success();
4721}
4722
// Printer for a device-type array: prints `([dt, dt, ...])`. Prints nothing
// when the array has no values or holds exactly one DeviceType::None entry.
// NOTE(review): the line carrying the function name and first parameters is
// not present in this extract.
4723static void
4725 std::optional<mlir::ArrayAttr> deviceTypes) {
4726
// Keyword-only form: a lone DeviceType::None entry is not printed.
4727 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
// NOTE(review): the dyn_cast result is used without a null check; this
// relies on every element being a DeviceTypeAttr.
4728 auto deviceTypeAttr =
4729 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4730 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4731 return;
4732 }
4733
4734 if (!hasDeviceTypeValues(deviceTypes))
4735 return;
4736
4737 p << "([";
4738 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4739 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4740 p << dTypeAttr;
4741 });
4742 p << "])";
4743}
4744
4745bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4746
4747bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4748 return hasDeviceType(getWorker(), deviceType);
4749}
4750
4751bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4752
4753bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4754 return hasDeviceType(getVector(), deviceType);
4755}
4756
4757bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4758
4759bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4760 return hasDeviceType(getSeq(), deviceType);
4761}
4762
4763std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4764RoutineOp::getBindNameValue() {
4765 return getBindNameValue(mlir::acc::DeviceType::None);
4766}
4767
4768std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4769RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4770 if (hasDeviceTypeValues(getBindIdNameDeviceType())) {
4771 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4772 auto attr = (*getBindIdName())[*pos];
4773 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4774 assert(symbolRefAttr && "expected SymbolRef");
4775 return symbolRefAttr;
4776 }
4777 }
4778
4779 if (hasDeviceTypeValues(getBindStrNameDeviceType())) {
4780 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4781 auto attr = (*getBindStrName())[*pos];
4782 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4783 assert(stringAttr && "expected String");
4784 return stringAttr;
4785 }
4786 }
4787
4788 return std::nullopt;
4789}
4790
4791bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4792
4793bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4794 return hasDeviceType(getGang(), deviceType);
4795}
4796
4797std::optional<int64_t> RoutineOp::getGangDimValue() {
4798 return getGangDimValue(mlir::acc::DeviceType::None);
4799}
4800
4801std::optional<int64_t>
4802RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4803 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4804 return std::nullopt;
4805 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4806 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4807 return intAttr.getInt();
4808 }
4809 return std::nullopt;
4810}
4811
4812void RoutineOp::addSeq(MLIRContext *context,
4813 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4814 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4815 effectiveDeviceTypes));
4816}
4817
4818void RoutineOp::addVector(MLIRContext *context,
4819 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4820 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4821 effectiveDeviceTypes));
4822}
4823
4824void RoutineOp::addWorker(MLIRContext *context,
4825 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4826 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4827 effectiveDeviceTypes));
4828}
4829
4830void RoutineOp::addGang(MLIRContext *context,
4831 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4832 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4833 effectiveDeviceTypes));
4834}
4835
// Appends a gang dimension value `val` (as a 64-bit IntegerAttr) for each
// device type in `effectiveDeviceTypes`, or once for DeviceType::None when
// the list is empty, keeping the gang-dim and gang-dim-device-type arrays the
// same length.
// NOTE(review): the declarations of `dimValues` and `deviceTypes` are not
// present in this extract; confirm them against the full file.
4836void RoutineOp::addGang(MLIRContext *context,
4837 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4838 uint64_t val) {
4841
// Start from the entries already stored on the op.
4842 if (getGangDimAttr())
4843 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4844 if (getGangDimDeviceTypeAttr())
4845 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4846
4847 assert(dimValues.size() == deviceTypes.size());
4848
4849 if (effectiveDeviceTypes.empty()) {
4850 dimValues.push_back(
4851 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4852 deviceTypes.push_back(
4853 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4854 } else {
4855 for (DeviceType dt : effectiveDeviceTypes) {
4856 dimValues.push_back(
4857 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4858 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4859 }
4860 }
4861 assert(dimValues.size() == deviceTypes.size());
4862
4863 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4864 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4865}
4866
// Adds the string bind name `val` for `effectiveDeviceTypes`: the device-type
// array is updated through addDeviceTypeAffectedOperandHelper, and `val` is
// appended to the name array once per device-type entry that was added.
// NOTE(review): the declaration of `vals` is not present in this extract.
4867void RoutineOp::addBindStrName(MLIRContext *context,
4868 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4869 mlir::StringAttr val) {
// Size of the device-type array before the update (0 when unset).
4870 unsigned before = getBindStrNameDeviceTypeAttr()
4871 ? getBindStrNameDeviceTypeAttr().size()
4872 : 0;
4873
4874 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4875 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4876 unsigned after = getBindStrNameDeviceTypeAttr().size();
4877
4879 if (getBindStrNameAttr())
4880 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
// One copy of `val` per newly added device type keeps both arrays aligned.
4881 for (unsigned i = 0; i < after - before; ++i)
4882 vals.push_back(val);
4883
4884 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4885}
4886
// Adds the symbol-reference bind name `val` for `effectiveDeviceTypes`: the
// device-type array is updated through addDeviceTypeAffectedOperandHelper,
// and `val` is appended to the name array once per device-type entry added.
// NOTE(review): the declaration of `vals` is not present in this extract.
4887void RoutineOp::addBindIDName(MLIRContext *context,
4888 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4889 mlir::SymbolRefAttr val) {
// Size of the device-type array before the update (0 when unset).
4890 unsigned before =
4891 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4892
4893 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4894 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4895 unsigned after = getBindIdNameDeviceTypeAttr().size();
4896
4898 if (getBindIdNameAttr())
4899 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
// One copy of `val` per newly added device type keeps both arrays aligned.
4900 for (unsigned i = 0; i < after - before; ++i)
4901 vals.push_back(val);
4902
4903 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4904}
4905
4906//===----------------------------------------------------------------------===//
4907// InitOp
4908//===----------------------------------------------------------------------===//
4909
4910LogicalResult acc::InitOp::verify() {
4911 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4912 return emitOpError("cannot be nested in a compute operation");
4913 return success();
4914}
4915
// Appends `deviceType` to the op's device-types array attribute, keeping any
// entries that are already present.
// NOTE(review): the declaration of `deviceTypes` is not present in this
// extract; confirm it against the full file.
4916void acc::InitOp::addDeviceType(MLIRContext *context,
4917 mlir::acc::DeviceType deviceType) {
4919 if (getDeviceTypesAttr())
4920 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4921
4922 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4923 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4924}
4925
4926//===----------------------------------------------------------------------===//
4927// ShutdownOp
4928//===----------------------------------------------------------------------===//
4929
4930LogicalResult acc::ShutdownOp::verify() {
4931 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4932 return emitOpError("cannot be nested in a compute operation");
4933 return success();
4934}
4935
// Appends `deviceType` to the op's device-types array attribute, keeping any
// entries that are already present.
// NOTE(review): the declaration of `deviceTypes` is not present in this
// extract; confirm it against the full file.
4936void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4937 mlir::acc::DeviceType deviceType) {
4939 if (getDeviceTypesAttr())
4940 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4941
4942 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4943 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4944}
4945
4946//===----------------------------------------------------------------------===//
4947// SetOp
4948//===----------------------------------------------------------------------===//
4949
4950LogicalResult acc::SetOp::verify() {
4951 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4952 return emitOpError("cannot be nested in a compute operation");
4953 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4954 return emitOpError("at least one default_async, device_num, or device_type "
4955 "operand must appear");
4956 return success();
4957}
4958
4959//===----------------------------------------------------------------------===//
4960// UpdateOp
4961//===----------------------------------------------------------------------===//
4962
// Verifies UpdateOp: the data clause operand list must be non-empty, the
// async and wait operand groups must pass their checks, and every data clause
// operand must be defined by UpdateDeviceOp, UpdateHostOp or GetDevicePtrOp.
// NOTE(review): several lines of this function (the calls whose arguments
// appear below, and the condition before the third `return failure()`) are
// not present in this extract; confirm them against the full file.
4963LogicalResult acc::UpdateOp::verify() {
4964 // At least one of host or device should have a value.
4965 if (getDataClauseOperands().empty())
4966 return emitError("at least one value must be present in dataOperands");
4967
4969 getAsyncOperandsDeviceTypeAttr(),
4970 "async")))
4971 return failure();
4972
4974 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4975 getWaitOperandsDeviceTypeAttr(), "wait")))
4976 return failure();
4977
4979 return failure();
4980
// NOTE(review): getDefiningOp() is passed straight to isa<>; presumably data
// operands are never block arguments here. Verify, since isa<> is not meant
// to be used on a null operation.
4981 for (mlir::Value operand : getDataClauseOperands())
4982 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4983 operand.getDefiningOp()))
4984 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4985 "as defining op");
4986
4987 return success();
4988}
4989
4990unsigned UpdateOp::getNumDataOperands() {
4991 return getDataClauseOperands().size();
4992}
4993
4994Value UpdateOp::getDataOperand(unsigned i) {
4995 unsigned numOptional = getAsyncOperands().size();
4996 numOptional += getIfCond() ? 1 : 0;
4997 return getOperand(getWaitOperands().size() + numOptional + i);
4998}
4999
5000void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
5001 MLIRContext *context) {
5002 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
5003}
5004
5005bool UpdateOp::hasAsyncOnly() {
5006 return hasAsyncOnly(mlir::acc::DeviceType::None);
5007}
5008
5009bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5010 return hasDeviceType(getAsyncOnly(), deviceType);
5011}
5012
5013mlir::Value UpdateOp::getAsyncValue() {
5014 return getAsyncValue(mlir::acc::DeviceType::None);
5015}
5016
// Returns the async operand whose device-type entry matches `deviceType`, or
// a null Value when no entry matches.
// NOTE(review): the condition guarding the first `return {}` is not present
// in this extract; confirm it against the full file.
5017mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
5019 return {};
5020
// The position found in the device-type array indexes the async operands.
5021 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
5022 return getAsyncOperands()[*pos];
5023
5024 return {};
5025}
5026
5027bool UpdateOp::hasWaitOnly() {
5028 return hasWaitOnly(mlir::acc::DeviceType::None);
5029}
5030
5031bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5032 return hasDeviceType(getWaitOnly(), deviceType);
5033}
5034
5035mlir::Operation::operand_range UpdateOp::getWaitValues() {
5036 return getWaitValues(mlir::acc::DeviceType::None);
5037}
5038
// Returns the wait operands for `deviceType`, computed from the wait operand
// device types, operands, segments and has-devnum flags.
// NOTE(review): the return-type line and the line naming the helper being
// called are not present in this extract; confirm them against the full file.
5040UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
5042 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
5043 getHasWaitDevnum(), deviceType);
5044}
5045
5046mlir::Value UpdateOp::getWaitDevnum() {
5047 return getWaitDevnum(mlir::acc::DeviceType::None);
5048}
5049
5050mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5051 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
5052 getWaitOperandsSegments(), getHasWaitDevnum(),
5053 deviceType);
5054}
5055
5056void UpdateOp::addAsyncOnly(MLIRContext *context,
5057 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5058 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5059 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5060}
5061
5062void UpdateOp::addAsyncOperand(
5063 MLIRContext *context, mlir::Value newValue,
5064 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5065 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5066 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5067 getAsyncOperandsMutable()));
5068}
5069
5070void UpdateOp::addWaitOnly(MLIRContext *context,
5071 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5072 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5073 effectiveDeviceTypes));
5074}
5075
// Adds `newValues` as wait operands for `effectiveDeviceTypes`, updating the
// wait operand device types and segments, and records `hasDevnum` in the
// has-wait-devnum array.
// NOTE(review): the declarations of `segments` and `hasDevnums` are not
// present in this extract; confirm them against the full file.
5076void UpdateOp::addWaitOperands(
5077 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
5078 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5079
// Start from the segments already stored on the op.
5081 if (getWaitOperandsSegments())
5082 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5083
5084 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5085 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5086 getWaitOperandsMutable(), segments));
5087 setWaitOperandsSegments(segments);
5088
5090 if (getHasWaitDevnumAttr())
5091 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
// One flag per effective device type, or a single flag when the list is empty.
5092 hasDevnums.insert(
5093 hasDevnums.end(),
5094 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
5095 mlir::BoolAttr::get(context, hasDevnum));
5096 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5097}
5098
5099//===----------------------------------------------------------------------===//
5100// WaitOp
5101//===----------------------------------------------------------------------===//
5102
5103LogicalResult acc::WaitOp::verify() {
5104 // The async attribute represent the async clause without value. Therefore the
5105 // attribute and operand cannot appear at the same time.
5106 if (getAsyncOperand() && getAsync())
5107 return emitError("async attribute cannot appear with asyncOperand");
5108
5109 if (getWaitDevnum() && getWaitOperands().empty())
5110 return emitError("wait_devnum cannot appear without waitOperands");
5111
5112 return success();
5113}
5114
5115#define GET_OP_CLASSES
5116#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5117
5118#define GET_ATTRDEF_CLASSES
5119#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5120
5121#define GET_TYPEDEF_CLASSES
5122#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5123
5124//===----------------------------------------------------------------------===//
5125// acc dialect utilities
5126//===----------------------------------------------------------------------===//
5127
// Dispatches on the type of `accDataClauseOp`: data entry ops, CopyoutOp and
// UpdateHostOp yield their getVarPtr(); any other op takes the default case.
// NOTE(review): the signature, the TypeSwitch result type and the default
// return value are not present in this extract.
5130 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
5132 accDataClauseOp)
5133 .Case<ACC_DATA_ENTRY_OPS>(
5134 [&](auto entry) { return entry.getVarPtr(); })
5135 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5136 [&](auto exit) { return exit.getVarPtr(); })
5137 .Default([&](mlir::Operation *) {
5139 })};
5140 return varPtr;
5141}
5142
// Dispatches on the op type: data entry ops yield their getVar(); any other
// op yields a null mlir::Value.
// NOTE(review): the signature and the TypeSwitch line are not present in this
// extract.
5144 auto varPtr{
5146 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
5147 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5148 return varPtr;
5149}
5150
// Dispatches on the type of `accDataClauseOp`: data entry ops, CopyoutOp and
// UpdateHostOp yield their getVarType(); any other op yields a null
// mlir::Type.
// NOTE(review): the signature line is not present in this extract.
5152 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
5153 .Case<ACC_DATA_ENTRY_OPS>(
5154 [&](auto entry) { return entry.getVarType(); })
5155 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5156 [&](auto exit) { return exit.getVarType(); })
5157 .Default([&](mlir::Operation *) { return mlir::Type(); })};
5158 return varType;
5159}
5160
// Dispatches on the type of `accDataClauseOp`: data entry and data exit ops
// yield their getAccPtr(); any other op takes the default case.
// NOTE(review): the signature, the TypeSwitch result type and the default
// return value are not present in this extract.
5163 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
5165 accDataClauseOp)
5166 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5167 [&](auto dataClause) { return dataClause.getAccPtr(); })
5168 .Default([&](mlir::Operation *) {
5170 })};
5171 return accPtr;
5172}
5173
// Dispatches on the type of `accDataClauseOp`: the matched ops yield their
// getAccVar(); any other op yields a null mlir::Value.
// NOTE(review): the signature and the `.Case<...>` line naming the matched op
// types are not present in this extract.
5175 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
5177 [&](auto dataClause) { return dataClause.getAccVar(); })
5178 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5179 return accPtr;
5180}
5181
// Dispatches on the op type: data entry ops yield their getVarPtrPtr(); any
// other op yields a null mlir::Value.
// NOTE(review): the signature and the TypeSwitch line are not present in this
// extract.
5183 auto varPtrPtr{
5185 .Case<ACC_DATA_ENTRY_OPS>(
5186 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
5187 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5188 return varPtrPtr;
5189}
5190
// Dispatches on the type of `accDataClauseOp`: for data entry and data exit
// ops the result is built from the [begin, end) range of getBounds(); any
// other op takes the default case.
// NOTE(review): the signature, the TypeSwitch head, the container type being
// constructed and the default return value are not present in this extract.
5195 accDataClauseOp)
5196 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5198 dataClause.getBounds().begin(), dataClause.getBounds().end());
5199 })
5200 .Default([&](mlir::Operation *) {
5202 })};
5203 return bounds;
5204}
5205
// Dispatches on the type of `accDataClauseOp`: for data entry and data exit
// ops the result is built from the [begin, end) range of getAsyncOperands();
// any other op takes the default case.
// NOTE(review): the signature, the TypeSwitch head, the container type being
// constructed and the default return value are not present in this extract.
5209 accDataClauseOp)
5210 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5212 dataClause.getAsyncOperands().begin(),
5213 dataClause.getAsyncOperands().end());
5214 })
5215 .Default([&](mlir::Operation *) {
5217 });
5218}
5219
// Dispatches on the op type: data entry and data exit ops yield their
// getAsyncOperandsDeviceTypeAttr(); any other op yields a null ArrayAttr.
// NOTE(review): the line with the function name and parameters and the
// TypeSwitch line are not present in this extract.
5220mlir::ArrayAttr
5223 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5224 return dataClause.getAsyncOperandsDeviceTypeAttr();
5225 })
5226 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5227}
5228
// Returns the async-only attribute (getAsyncOnlyAttr()) of the matched data
// clause ops; any other op yields a null ArrayAttr.
// NOTE(review): the TypeSwitch line and the `.Case<...>` line naming the
// matched op types are not present in this extract.
5229mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
5232 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
5233 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5234}
5235
// Returns the name (getName()) of a data entry op; any other op yields an
// empty optional.
// NOTE(review): the TypeSwitch line is not present in this extract.
5236std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
5237 auto name{
5239 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
5240 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
5241 return {};
5242 })};
5243 return name;
5244}
5245
// Returns the data clause (getDataClause()) of a data entry op; any other op
// yields std::nullopt.
// NOTE(review): the line with the function name and parameters and the
// TypeSwitch head are not present in this extract.
5246std::optional<mlir::acc::DataClause>
5248 auto dataClause{
5250 accDataEntryOp)
5251 .Case<ACC_DATA_ENTRY_OPS>(
5252 [&](auto entry) { return entry.getDataClause(); })
5253 .Default([&](mlir::Operation *) { return std::nullopt; })};
5254 return dataClause;
5255}
5256
// Returns getImplicit() of a data entry op; any other op yields false.
// NOTE(review): the signature line is not present in this extract.
5258 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5259 .Case<ACC_DATA_ENTRY_OPS>(
5260 [&](auto entry) { return entry.getImplicit(); })
5261 .Default([&](mlir::Operation *) { return false; })};
5262 return implicit;
5263}
5264
// Returns getDataClauseOperands() of the matched ops; any other op yields an
// empty mlir::ValueRange.
// NOTE(review): the signature, the TypeSwitch line and the `.Case<...>` line
// naming the matched op types are not present in this extract.
5266 auto dataOperands{
5269 [&](auto entry) { return entry.getDataClauseOperands(); })
5270 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5271 return dataOperands;
5272}
5273
// Returns getDataClauseOperandsMutable() of the matched ops; the default case
// returns nullptr.
// NOTE(review): the signature, the TypeSwitch line (including its result
// type) and the `.Case<...>` line naming the matched op types are not present
// in this extract.
5276 auto dataOperands{
5279 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5280 .Default([&](mlir::Operation *) { return nullptr; })};
5281 return dataOperands;
5282}
5283
// Returns the recipe symbol (getRecipeAttr()) of a data entry op; any other
// op yields a null SymbolRefAttr.
// NOTE(review): the TypeSwitch line is not present in this extract.
5284mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5285 auto recipe{
5287 .Case<ACC_DATA_ENTRY_OPS>(
5288 [&](auto entry) { return entry.getRecipeAttr(); })
5289 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5290 return recipe;
5291}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4667
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1530
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3418
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:2036
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4522
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:897
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:670
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2533
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:890
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:823
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:654
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:792
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2544
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2449
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:596
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4724
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3227
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2784
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3434
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4429
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:730
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2738
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:614
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:712
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:746
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1295
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3802
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1991
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2580
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2119
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1305
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:738
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:801
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:625
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:638
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2319
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:493
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:777
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3833
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:504
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4697
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4606
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2432
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2607
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2723
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2386
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1289
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2699
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:867
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3246
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1762
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2768
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2363
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:756
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:2005
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2680
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:600
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3373
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2618
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:835
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:690
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:2047
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4488
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2369
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2804
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4576
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
IntegerType getI64Type()
Definition Builders.cpp:69
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:119
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:157
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperandRange operand_range
Definition Operation.h:397
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:608
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
result_range getResults()
Definition Operation.h:441
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
Base attribute class for language-specific variable information carried through the OpenACC type inte...
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:369
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_CONSTRUCT_OPS
Definition OpenACC.h:62
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:73
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:48
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:58
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5174
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5143
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5162
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5247
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5275
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5192
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5265
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5236
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5257
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5284
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5207
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:171
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5182
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:208
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5229
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:5151
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5129
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5221
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Value genCast(OpBuilder &builder, Location loc, Value value, Type dstTy)
Add type casting between arith and index types when needed.
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.