MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248
249 bool isDeviceData(Type pointer, Value var) const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
253 }
254};
255
256struct LLVMPointerPointerLikeModel
257 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
259 Type getElementType(Type pointer) const { return Type(); }
260
261 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType) const {
264 // For LLVM pointers, we need the valueType to determine what to load
265 if (!valueType)
266 return {};
267
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
269 }
270
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
272 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
274 return true;
275 }
276};
277
278struct MemrefAddressOfGlobalModel
279 : public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op) const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
284 }
285};
286
287struct MemrefGlobalVariableModel
288 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
289 memref::GlobalOp> {
290 bool isConstant(Operation *op) const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
293 }
294
295 Region *getInitRegion(Operation *op) const {
296 // GlobalOp uses attributes for initialization, not regions
297 return nullptr;
298 }
299
300 bool isDeviceData(Operation *op) const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
304 }
305};
306
307struct GPULaunchOffloadRegionModel
308 : public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
311 return cast<gpu::LaunchOp>(op).getBody();
312 }
313};
314
315/// Helper function for any of the times we need to modify an ArrayAttr based on
316/// a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
318/// list is empty).
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
321 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
325
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
329
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
332
333 return mlir::ArrayAttr::get(context, deviceTypes);
334}
335
336/// Helper function for any of the times we need to add operands that are
337/// affected by a device type list. Returns a new ArrayAttr with all of the
338/// existingDeviceTypes, plus the effective new ones (or an added none, if the
339/// new list is empty). Additionally, adds the arguments to the argCollection
340/// the correct number of times. This will also update a 'segments' array, even
341/// if it won't be used.
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
344 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
345 mlir::MutableOperandRange argCollection,
346 llvm::SmallVector<int32_t> &segments) {
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
350
351 if (newDeviceTypes.empty()) {
352 argCollection.append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
356 }
357
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
362 }
363
364 return mlir::ArrayAttr::get(context, deviceTypes);
365}
366
367/// Overload for when the 'segments' aren't needed.
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
370 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
371 mlir::MutableOperandRange argCollection) {
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
376}
377} // namespace
378
379//===----------------------------------------------------------------------===//
380// OpenACC operations
381//===----------------------------------------------------------------------===//
382
383void OpenACCDialect::initialize() {
384 addOperations<
385#define GET_OP_LIST
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
387 >();
388 addAttributes<
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
391 >();
392 addTypes<
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
395 >();
396
397 // By attaching interfaces here, we make the OpenACC dialect dependent on
398 // the other dialects. This is probably better than having dialects like LLVM
399 // and memref be dependent on OpenACC.
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
401 *getContext());
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
405 *getContext());
406
407 // Attach operation interfaces
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
409 *getContext());
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
412}
413
414//===----------------------------------------------------------------------===//
415// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
416// acc.kernel_environment / acc.data / acc.host_data / acc.loop
417//===----------------------------------------------------------------------===//
418
419/// Generic helper for single-region OpenACC ops that execute their body once
420/// and then return to the parent operation with their results (if any).
421static void
423 RegionBranchPoint point,
425 if (point.isParent()) {
426 regions.push_back(RegionSuccessor(&region));
427 return;
428 }
429
430 regions.push_back(RegionSuccessor::parent());
431}
432
434 RegionSuccessor successor) {
435 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
436}
437
438void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
440 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
441 regions);
442}
443
444ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
445 return getSingleRegionSuccessorInputs(getOperation(), successor);
446}
447
448void ParallelOp::getSuccessorRegions(
450 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
451 regions);
452}
453
454ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
455 return getSingleRegionSuccessorInputs(getOperation(), successor);
456}
457
458void SerialOp::getSuccessorRegions(RegionBranchPoint point,
460 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
461 regions);
462}
463
464ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
465 return getSingleRegionSuccessorInputs(getOperation(), successor);
466}
467
468void KernelEnvironmentOp::getSuccessorRegions(
470 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
471 regions);
472}
473
474ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
475 return getSingleRegionSuccessorInputs(getOperation(), successor);
476}
477
478void DataOp::getSuccessorRegions(RegionBranchPoint point,
480 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
481 regions);
482}
483
484ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
485 return getSingleRegionSuccessorInputs(getOperation(), successor);
486}
487
488void HostDataOp::getSuccessorRegions(
490 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
491 regions);
492}
493
494ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
495 return getSingleRegionSuccessorInputs(getOperation(), successor);
496}
497
498void LoopOp::getSuccessorRegions(RegionBranchPoint point,
500 // Unstructured loops: the body may contain arbitrary CFG and early exits.
501 // At the RegionBranch level, only model entry into the body and exit to the
502 // parent; any backedges are represented inside the region CFG.
503 if (getUnstructured()) {
504 if (point.isParent()) {
505 regions.push_back(RegionSuccessor(&getRegion()));
506 return;
507 }
508 regions.push_back(RegionSuccessor::parent());
509 return;
510 }
511
512 // Structured loops: model a loop-shaped region graph similar to scf.for.
513 regions.push_back(RegionSuccessor(&getRegion()));
514 regions.push_back(RegionSuccessor::parent());
515}
516
517ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
518 return getSingleRegionSuccessorInputs(getOperation(), successor);
519}
520
521//===----------------------------------------------------------------------===//
522// RegionBranchTerminatorOpInterface
523//===----------------------------------------------------------------------===//
524
526TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
527 // `acc.terminator` does not forward operands.
528 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
529}
530
531//===----------------------------------------------------------------------===//
532// device_type support helpers
533//===----------------------------------------------------------------------===//
534
535static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
536 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
537}
538
539static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
540 mlir::acc::DeviceType deviceType) {
541 if (!hasDeviceTypeValues(arrayAttr))
542 return false;
543
544 for (auto attr : *arrayAttr) {
545 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
546 if (deviceTypeAttr.getValue() == deviceType)
547 return true;
548 }
549
550 return false;
551}
552
554 std::optional<mlir::ArrayAttr> deviceTypes) {
555 if (!hasDeviceTypeValues(deviceTypes))
556 return;
557
558 p << "[";
559 llvm::interleaveComma(*deviceTypes, p,
560 [&](mlir::Attribute attr) { p << attr; });
561 p << "]";
562}
563
564static std::optional<unsigned> findSegment(ArrayAttr segments,
565 mlir::acc::DeviceType deviceType) {
566 unsigned segmentIdx = 0;
567 for (auto attr : segments) {
568 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
569 if (deviceTypeAttr.getValue() == deviceType)
570 return std::make_optional(segmentIdx);
571 ++segmentIdx;
572 }
573 return std::nullopt;
574}
575
577getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
579 std::optional<llvm::ArrayRef<int32_t>> segments,
580 mlir::acc::DeviceType deviceType) {
581 if (!arrayAttr)
582 return range.take_front(0);
583 if (auto pos = findSegment(*arrayAttr, deviceType)) {
584 int32_t nbOperandsBefore = 0;
585 for (unsigned i = 0; i < *pos; ++i)
586 nbOperandsBefore += (*segments)[i];
587 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
588 }
589 return range.take_front(0);
590}
591
592static mlir::Value
593getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
595 std::optional<llvm::ArrayRef<int32_t>> segments,
596 std::optional<mlir::ArrayAttr> hasWaitDevnum,
597 mlir::acc::DeviceType deviceType) {
598 if (!hasDeviceTypeValues(deviceTypeAttr))
599 return {};
600 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
601 if (hasWaitDevnum->getValue()[*pos])
602 return getValuesFromSegments(deviceTypeAttr, operands, segments,
603 deviceType)
604 .front();
605 return {};
606}
607
609getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
611 std::optional<llvm::ArrayRef<int32_t>> segments,
612 std::optional<mlir::ArrayAttr> hasWaitDevnum,
613 mlir::acc::DeviceType deviceType) {
614 auto range =
615 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
616 if (range.empty())
617 return range;
618 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
619 if (hasWaitDevnum && *hasWaitDevnum) {
620 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
621 if (boolAttr.getValue())
622 return range.drop_front(1); // first value is devnum
623 }
624 }
625 return range;
626}
627
628template <typename Op>
629static LogicalResult checkWaitAndAsyncConflict(Op op) {
630 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
631 ++dtypeInt) {
632 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
633
634 // The asyncOnly attribute represent the async clause without value.
635 // Therefore the attribute and operand cannot appear at the same time.
636 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
637 op.hasAsyncOnly(dtype))
638 return op.emitError(
639 "asyncOnly attribute cannot appear with asyncOperand");
640
641 // The wait attribute represent the wait clause without values. Therefore
642 // the attribute and operands cannot appear at the same time.
643 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
644 op.hasWaitOnly(dtype))
645 return op.emitError("wait attribute cannot appear with waitOperands");
646 }
647 return success();
648}
649
650template <typename Op>
651static LogicalResult checkVarAndVarType(Op op) {
652 if (!op.getVar())
653 return op.emitError("must have var operand");
654
655 // A variable must have a type that is either pointer-like or mappable.
656 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
657 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
658 return op.emitError("var must be mappable or pointer-like");
659
660 // When it is a pointer-like type, the varType must capture the target type.
661 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
662 op.getVarType() == op.getVar().getType())
663 return op.emitError("varType must capture the element type of var");
664
665 return success();
666}
667
668template <typename Op>
669static LogicalResult checkVarAndAccVar(Op op) {
670 if (op.getVar().getType() != op.getAccVar().getType())
671 return op.emitError("input and output types must match");
672
673 return success();
674}
675
676template <typename Op>
677static LogicalResult checkNoModifier(Op op) {
678 if (op.getModifiers() != acc::DataClauseModifier::none)
679 return op.emitError("no data clause modifiers are allowed");
680 return success();
681}
682
683template <typename Op>
684static LogicalResult
685checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
686 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
687 return op.emitError(
688 "invalid data clause modifiers: " +
689 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
690
691 return success();
692}
693
694template <typename OpT, typename RecipeOpT>
695static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
696 // Mappable types do not need a recipe because it is possible to generate one
697 // from its API. Reject reductions though because no API is available for them
698 // at this time.
699 if (mlir::acc::isMappableType(op.getVar().getType()) &&
700 !std::is_same_v<OpT, acc::ReductionOp>)
701 return success();
702
703 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
704 if (!operandRecipe)
705 return op->emitOpError() << "recipe expected for " << operandName;
706
707 auto decl =
709 if (!decl)
710 return op->emitOpError()
711 << "expected symbol reference " << operandRecipe << " to point to a "
712 << operandName << " declaration";
713 return success();
714}
715
716static ParseResult parseVar(mlir::OpAsmParser &parser,
718 // Either `var` or `varPtr` keyword is required.
719 if (failed(parser.parseOptionalKeyword("varPtr"))) {
720 if (failed(parser.parseKeyword("var")))
721 return failure();
722 }
723 if (failed(parser.parseLParen()))
724 return failure();
725 if (failed(parser.parseOperand(var)))
726 return failure();
727
728 return success();
729}
730
732 mlir::Value var) {
733 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
734 p << "varPtr(";
735 else
736 p << "var(";
737 p.printOperand(var);
738}
739
740static ParseResult parseAccVar(mlir::OpAsmParser &parser,
742 mlir::Type &accVarType) {
743 // Either `accVar` or `accPtr` keyword is required.
744 if (failed(parser.parseOptionalKeyword("accPtr"))) {
745 if (failed(parser.parseKeyword("accVar")))
746 return failure();
747 }
748 if (failed(parser.parseLParen()))
749 return failure();
750 if (failed(parser.parseOperand(var)))
751 return failure();
752 if (failed(parser.parseColon()))
753 return failure();
754 if (failed(parser.parseType(accVarType)))
755 return failure();
756 if (failed(parser.parseRParen()))
757 return failure();
758
759 return success();
760}
761
763 mlir::Value accVar, mlir::Type accVarType) {
764 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
765 p << "accPtr(";
766 else
767 p << "accVar(";
768 p.printOperand(accVar);
769 p << " : ";
770 p.printType(accVarType);
771 p << ")";
772}
773
774static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
775 mlir::Type &varPtrType,
776 mlir::TypeAttr &varTypeAttr) {
777 if (failed(parser.parseType(varPtrType)))
778 return failure();
779 if (failed(parser.parseRParen()))
780 return failure();
781
782 if (succeeded(parser.parseOptionalKeyword("varType"))) {
783 if (failed(parser.parseLParen()))
784 return failure();
785 mlir::Type varType;
786 if (failed(parser.parseType(varType)))
787 return failure();
788 varTypeAttr = mlir::TypeAttr::get(varType);
789 if (failed(parser.parseRParen()))
790 return failure();
791 } else {
792 // Set `varType` from the element type of the type of `varPtr`.
793 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
794 varTypeAttr = mlir::TypeAttr::get(
795 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
796 else
797 varTypeAttr = mlir::TypeAttr::get(varPtrType);
798 }
799
800 return success();
801}
802
804 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
805 p.printType(varPtrType);
806 p << ")";
807
808 // Print the `varType` only if it differs from the element type of
809 // `varPtr`'s type.
810 mlir::Type varType = varTypeAttr.getValue();
811 mlir::Type typeToCheckAgainst =
812 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
813 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
814 : varPtrType;
815 if (typeToCheckAgainst != varType) {
816 p << " varType(";
817 p.printType(varType);
818 p << ")";
819 }
820}
821
822static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
823 mlir::SymbolRefAttr &recipeAttr) {
824 if (failed(parser.parseAttribute(recipeAttr)))
825 return failure();
826 return success();
827}
828
830 mlir::SymbolRefAttr recipeAttr) {
831 p << recipeAttr;
832}
833
834//===----------------------------------------------------------------------===//
835// DataBoundsOp
836//===----------------------------------------------------------------------===//
837LogicalResult acc::DataBoundsOp::verify() {
838 auto extent = getExtent();
839 auto upperbound = getUpperbound();
840 if (!extent && !upperbound)
841 return emitError("expected extent or upperbound.");
842 return success();
843}
844
845//===----------------------------------------------------------------------===//
846// PrivateOp
847//===----------------------------------------------------------------------===//
848LogicalResult acc::PrivateOp::verify() {
849 if (getDataClause() != acc::DataClause::acc_private)
850 return emitError(
851 "data clause associated with private operation must match its intent");
852 if (failed(checkVarAndVarType(*this)))
853 return failure();
854 if (failed(checkNoModifier(*this)))
855 return failure();
856 if (failed(
858 return failure();
859 return success();
860}
861
862//===----------------------------------------------------------------------===//
863// FirstprivateOp
864//===----------------------------------------------------------------------===//
865LogicalResult acc::FirstprivateOp::verify() {
866 if (getDataClause() != acc::DataClause::acc_firstprivate)
867 return emitError("data clause associated with firstprivate operation must "
868 "match its intent");
869 if (failed(checkVarAndVarType(*this)))
870 return failure();
871 if (failed(checkNoModifier(*this)))
872 return failure();
874 *this, "firstprivate")))
875 return failure();
876 return success();
877}
878
879//===----------------------------------------------------------------------===//
880// FirstprivateMapInitialOp
881//===----------------------------------------------------------------------===//
882LogicalResult acc::FirstprivateMapInitialOp::verify() {
883 if (getDataClause() != acc::DataClause::acc_firstprivate)
884 return emitError("data clause associated with firstprivate operation must "
885 "match its intent");
886 if (failed(checkVarAndVarType(*this)))
887 return failure();
888 if (failed(checkNoModifier(*this)))
889 return failure();
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// ReductionOp
895//===----------------------------------------------------------------------===//
896LogicalResult acc::ReductionOp::verify() {
897 if (getDataClause() != acc::DataClause::acc_reduction)
898 return emitError("data clause associated with reduction operation must "
899 "match its intent");
900 if (failed(checkVarAndVarType(*this)))
901 return failure();
902 if (failed(checkNoModifier(*this)))
903 return failure();
905 *this, "reduction")))
906 return failure();
907 return success();
908}
909
910//===----------------------------------------------------------------------===//
911// DevicePtrOp
912//===----------------------------------------------------------------------===//
913LogicalResult acc::DevicePtrOp::verify() {
914 if (getDataClause() != acc::DataClause::acc_deviceptr)
915 return emitError("data clause associated with deviceptr operation must "
916 "match its intent");
917 if (failed(checkVarAndVarType(*this)))
918 return failure();
919 if (failed(checkVarAndAccVar(*this)))
920 return failure();
921 if (failed(checkNoModifier(*this)))
922 return failure();
923 return success();
924}
925
926//===----------------------------------------------------------------------===//
927// PresentOp
928//===----------------------------------------------------------------------===//
929LogicalResult acc::PresentOp::verify() {
930 if (getDataClause() != acc::DataClause::acc_present)
931 return emitError(
932 "data clause associated with present operation must match its intent");
933 if (failed(checkVarAndVarType(*this)))
934 return failure();
935 if (failed(checkVarAndAccVar(*this)))
936 return failure();
937 if (failed(checkNoModifier(*this)))
938 return failure();
939 return success();
940}
941
942//===----------------------------------------------------------------------===//
943// CopyinOp
944//===----------------------------------------------------------------------===//
945LogicalResult acc::CopyinOp::verify() {
946 // Test for all clauses this operation can be decomposed from:
947 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
948 getDataClause() != acc::DataClause::acc_copyin_readonly &&
949 getDataClause() != acc::DataClause::acc_copy &&
950 getDataClause() != acc::DataClause::acc_reduction)
951 return emitError(
952 "data clause associated with copyin operation must match its intent"
953 " or specify original clause this operation was decomposed from");
954 if (failed(checkVarAndVarType(*this)))
955 return failure();
956 if (failed(checkVarAndAccVar(*this)))
957 return failure();
958 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
959 acc::DataClauseModifier::always |
960 acc::DataClauseModifier::capture)))
961 return failure();
962 return success();
963}
964
965bool acc::CopyinOp::isCopyinReadonly() {
966 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
967 acc::bitEnumContainsAny(getModifiers(),
968 acc::DataClauseModifier::readonly);
969}
970
971//===----------------------------------------------------------------------===//
972// CreateOp
973//===----------------------------------------------------------------------===//
974LogicalResult acc::CreateOp::verify() {
975 // Test for all clauses this operation can be decomposed from:
976 if (getDataClause() != acc::DataClause::acc_create &&
977 getDataClause() != acc::DataClause::acc_create_zero &&
978 getDataClause() != acc::DataClause::acc_copyout &&
979 getDataClause() != acc::DataClause::acc_copyout_zero)
980 return emitError(
981 "data clause associated with create operation must match its intent"
982 " or specify original clause this operation was decomposed from");
983 if (failed(checkVarAndVarType(*this)))
984 return failure();
985 if (failed(checkVarAndAccVar(*this)))
986 return failure();
987 // this op is the entry part of copyout, so it also needs to allow all
988 // modifiers allowed on copyout.
989 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
990 acc::DataClauseModifier::always |
991 acc::DataClauseModifier::capture)))
992 return failure();
993 return success();
994}
995
996bool acc::CreateOp::isCreateZero() {
997 // The zero modifier is encoded in the data clause.
998 return getDataClause() == acc::DataClause::acc_create_zero ||
999 getDataClause() == acc::DataClause::acc_copyout_zero ||
1000 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// NoCreateOp
1005//===----------------------------------------------------------------------===//
1006LogicalResult acc::NoCreateOp::verify() {
1007 if (getDataClause() != acc::DataClause::acc_no_create)
1008 return emitError("data clause associated with no_create operation must "
1009 "match its intent");
1010 if (failed(checkVarAndVarType(*this)))
1011 return failure();
1012 if (failed(checkVarAndAccVar(*this)))
1013 return failure();
1014 if (failed(checkNoModifier(*this)))
1015 return failure();
1016 return success();
1017}
1018
1019//===----------------------------------------------------------------------===//
1020// AttachOp
1021//===----------------------------------------------------------------------===//
1022LogicalResult acc::AttachOp::verify() {
1023 if (getDataClause() != acc::DataClause::acc_attach)
1024 return emitError(
1025 "data clause associated with attach operation must match its intent");
1026 if (failed(checkVarAndVarType(*this)))
1027 return failure();
1028 if (failed(checkVarAndAccVar(*this)))
1029 return failure();
1030 if (failed(checkNoModifier(*this)))
1031 return failure();
1032 return success();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// DeclareDeviceResidentOp
1037//===----------------------------------------------------------------------===//
1038
1039LogicalResult acc::DeclareDeviceResidentOp::verify() {
1040 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1041 return emitError("data clause associated with device_resident operation "
1042 "must match its intent");
1043 if (failed(checkVarAndVarType(*this)))
1044 return failure();
1045 if (failed(checkVarAndAccVar(*this)))
1046 return failure();
1047 if (failed(checkNoModifier(*this)))
1048 return failure();
1049 return success();
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// DeclareLinkOp
1054//===----------------------------------------------------------------------===//
1055
1056LogicalResult acc::DeclareLinkOp::verify() {
1057 if (getDataClause() != acc::DataClause::acc_declare_link)
1058 return emitError(
1059 "data clause associated with link operation must match its intent");
1060 if (failed(checkVarAndVarType(*this)))
1061 return failure();
1062 if (failed(checkVarAndAccVar(*this)))
1063 return failure();
1064 if (failed(checkNoModifier(*this)))
1065 return failure();
1066 return success();
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// CopyoutOp
1071//===----------------------------------------------------------------------===//
1072LogicalResult acc::CopyoutOp::verify() {
1073 // Test for all clauses this operation can be decomposed from:
1074 if (getDataClause() != acc::DataClause::acc_copyout &&
1075 getDataClause() != acc::DataClause::acc_copyout_zero &&
1076 getDataClause() != acc::DataClause::acc_copy &&
1077 getDataClause() != acc::DataClause::acc_reduction)
1078 return emitError(
1079 "data clause associated with copyout operation must match its intent"
1080 " or specify original clause this operation was decomposed from");
1081 if (!getVar() || !getAccVar())
1082 return emitError("must have both host and device pointers");
1083 if (failed(checkVarAndVarType(*this)))
1084 return failure();
1085 if (failed(checkVarAndAccVar(*this)))
1086 return failure();
1087 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1088 acc::DataClauseModifier::always |
1089 acc::DataClauseModifier::capture)))
1090 return failure();
1091 return success();
1092}
1093
1094bool acc::CopyoutOp::isCopyoutZero() {
1095 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1096 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// DeleteOp
1101//===----------------------------------------------------------------------===//
1102LogicalResult acc::DeleteOp::verify() {
1103 // Test for all clauses this operation can be decomposed from:
1104 if (getDataClause() != acc::DataClause::acc_delete &&
1105 getDataClause() != acc::DataClause::acc_create &&
1106 getDataClause() != acc::DataClause::acc_create_zero &&
1107 getDataClause() != acc::DataClause::acc_copyin &&
1108 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1109 getDataClause() != acc::DataClause::acc_present &&
1110 getDataClause() != acc::DataClause::acc_no_create &&
1111 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1112 getDataClause() != acc::DataClause::acc_declare_link)
1113 return emitError(
1114 "data clause associated with delete operation must match its intent"
1115 " or specify original clause this operation was decomposed from");
1116 if (!getAccVar())
1117 return emitError("must have device pointer");
1118 // This op is the exit part of copyin and create - thus allow all modifiers
1119 // allowed on either case.
1120 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1121 acc::DataClauseModifier::readonly |
1122 acc::DataClauseModifier::always |
1123 acc::DataClauseModifier::capture)))
1124 return failure();
1125 return success();
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// DetachOp
1130//===----------------------------------------------------------------------===//
1131LogicalResult acc::DetachOp::verify() {
1132 // Test for all clauses this operation can be decomposed from:
1133 if (getDataClause() != acc::DataClause::acc_detach &&
1134 getDataClause() != acc::DataClause::acc_attach)
1135 return emitError(
1136 "data clause associated with detach operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1138 if (!getAccVar())
1139 return emitError("must have device pointer");
1140 if (failed(checkNoModifier(*this)))
1141 return failure();
1142 return success();
1143}
1144
1145//===----------------------------------------------------------------------===//
1146// UpdateHostOp
1147//===----------------------------------------------------------------------===//
1148LogicalResult acc::UpdateHostOp::verify() {
1149 // Test for all clauses this operation can be decomposed from:
1150 if (getDataClause() != acc::DataClause::acc_update_host &&
1151 getDataClause() != acc::DataClause::acc_update_self)
1152 return emitError(
1153 "data clause associated with host operation must match its intent"
1154 " or specify original clause this operation was decomposed from");
1155 if (!getVar() || !getAccVar())
1156 return emitError("must have both host and device pointers");
1157 if (failed(checkVarAndVarType(*this)))
1158 return failure();
1159 if (failed(checkVarAndAccVar(*this)))
1160 return failure();
1161 if (failed(checkNoModifier(*this)))
1162 return failure();
1163 return success();
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// UpdateDeviceOp
1168//===----------------------------------------------------------------------===//
1169LogicalResult acc::UpdateDeviceOp::verify() {
1170 // Test for all clauses this operation can be decomposed from:
1171 if (getDataClause() != acc::DataClause::acc_update_device)
1172 return emitError(
1173 "data clause associated with device operation must match its intent"
1174 " or specify original clause this operation was decomposed from");
1175 if (failed(checkVarAndVarType(*this)))
1176 return failure();
1177 if (failed(checkVarAndAccVar(*this)))
1178 return failure();
1179 if (failed(checkNoModifier(*this)))
1180 return failure();
1181 return success();
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// UseDeviceOp
1186//===----------------------------------------------------------------------===//
1187LogicalResult acc::UseDeviceOp::verify() {
1188 // Test for all clauses this operation can be decomposed from:
1189 if (getDataClause() != acc::DataClause::acc_use_device)
1190 return emitError(
1191 "data clause associated with use_device operation must match its intent"
1192 " or specify original clause this operation was decomposed from");
1193 if (failed(checkVarAndVarType(*this)))
1194 return failure();
1195 if (failed(checkVarAndAccVar(*this)))
1196 return failure();
1197 if (failed(checkNoModifier(*this)))
1198 return failure();
1199 return success();
1200}
1201
1202//===----------------------------------------------------------------------===//
1203// CacheOp
1204//===----------------------------------------------------------------------===//
1205LogicalResult acc::CacheOp::verify() {
1206 // Test for all clauses this operation can be decomposed from:
1207 if (getDataClause() != acc::DataClause::acc_cache &&
1208 getDataClause() != acc::DataClause::acc_cache_readonly)
1209 return emitError(
1210 "data clause associated with cache operation must match its intent"
1211 " or specify original clause this operation was decomposed from");
1212 if (failed(checkVarAndVarType(*this)))
1213 return failure();
1214 if (failed(checkVarAndAccVar(*this)))
1215 return failure();
1216 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1217 return failure();
1218 return success();
1219}
1220
1221bool acc::CacheOp::isCacheReadonly() {
1222 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1223 acc::bitEnumContainsAny(getModifiers(),
1224 acc::DataClauseModifier::readonly);
1225}
1226
1227//===----------------------------------------------------------------------===//
1228// Data entry/exit operations - getEffects implementations
1229//===----------------------------------------------------------------------===//
1230
1231// This function returns true iff the given operation is enclosed
1232// in any ACC_COMPUTE_CONSTRUCT_OPS operation.
1233// It is similar to the acc::getEnclosingComputeOp() utility,
1234// but we cannot use it here.
// It walks every ancestor, starting at the immediate parent, and stops at the
// first one that is an ACC_COMPUTE_CONSTRUCT_OPS operation. `op` itself is not
// tested.
// NOTE(review): the signature line (1235) is missing from this listing.
1236 mlir::Operation *parentOp = op->getParentOp();
1237 while (parentOp) {
1238 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
1239 return true;
1240 parentOp = parentOp->getParentOp();
1241 }
1242 return false;
1243}
1244
1245/// Helper to add an effect on an operand, referenced by its mutable range.
/// One `EffectTy` instance is appended per OpOperand in `operand`, so an empty
/// range (for example an absent optional operand) appends nothing.
/// NOTE(review): the declaration lines (1247-1248) are missing from this
/// listing.
1246template <typename EffectTy>
1249 &effects,
1250 MutableOperandRange operand) {
1251 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1252 effects.emplace_back(EffectTy::get(), &operand[i]);
1253}
1254
1255/// Helper to add an effect on a result value.
/// `result` must be an OpResult: it is converted with mlir::cast, which
/// asserts on any other Value kind (e.g. a block argument).
/// NOTE(review): the declaration lines (1257-1258) are missing from this
/// listing.
1256template <typename EffectTy>
1259 &effects,
1260 Value result) {
1261 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1262}
1263
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here.
1264// PrivateOp: accVar result write.
1265void acc::PrivateOp::getEffects(
1267 &effects) {
1268 // If acc.private is enclosed into a compute operation,
1269 // then it denotes the device side privatization, hence
1270 // it does not access the CurrentDeviceIdResource.
1271 if (!isEnclosedIntoComputeOp(getOperation()))
1272 effects.emplace_back(MemoryEffects::Read::get(),
1274 // TODO: should this be MemoryEffects::Allocate?
1276}
1277
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here.
1278// FirstprivateOp: var read, accVar result write.
1279void acc::FirstprivateOp::getEffects(
1281 &effects) {
1282 // If acc.firstprivate is enclosed into a compute operation,
1283 // then it denotes the device side privatization, hence
1284 // it does not access the CurrentDeviceIdResource.
1285 if (!isEnclosedIntoComputeOp(getOperation()))
1286 effects.emplace_back(MemoryEffects::Read::get(),
1288 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1290}
1291
// Unlike FirstprivateOp, the leading Read effect here is added
// unconditionally; there is no enclosing-compute-op check.
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here.
1292// FirstprivateMapInitialOp: var read, accVar result write.
1293void acc::FirstprivateMapInitialOp::getEffects(
1295 &effects) {
1296 effects.emplace_back(MemoryEffects::Read::get(),
1298 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1300}
1301
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here.
1302// ReductionOp: var read, accVar result write.
1303void acc::ReductionOp::getEffects(
1305 &effects) {
1306 // If acc.reduction is enclosed into a compute operation,
1307 // then it denotes the device side reduction, hence
1308 // it does not access the CurrentDeviceIdResource.
1309 if (!isEnclosedIntoComputeOp(getOperation()))
1310 effects.emplace_back(MemoryEffects::Read::get(),
1312 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1314}
1315
// Visible effects: a Read of acc::RuntimeCounters, plus a second Read.
// NOTE(review): the resource of the second Read is on a line missing from
// this listing.
1316// DevicePtrOp: RuntimeCounters read.
1317void acc::DevicePtrOp::getEffects(
1319 &effects) {
1320 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1321 effects.emplace_back(MemoryEffects::Read::get(),
1323}
1324
// Visible effects: Read of acc::RuntimeCounters, then a Write and a further
// Read.
// NOTE(review): the resources of the Write and the further Read are on lines
// missing from this listing.
1325// PresentOp: RuntimeCounters read+write.
1326void acc::PresentOp::getEffects(
1328 &effects) {
1329 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1330 effects.emplace_back(MemoryEffects::Write::get(),
1332 effects.emplace_back(MemoryEffects::Read::get(),
1334}
1335
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here. The visible var operand effect is a Read.
1336// CopyinOp: RuntimeCounters read+write, var read, accVar result write.
1337void acc::CopyinOp::getEffects(
1339 &effects) {
1340 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1341 effects.emplace_back(MemoryEffects::Write::get(),
1343 effects.emplace_back(MemoryEffects::Read::get(),
1345 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1347}
1348
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here. No effect on the var operand is added in the visible
// lines.
1349// CreateOp: RuntimeCounters read+write, accVar result write.
1350void acc::CreateOp::getEffects(
1352 &effects) {
1353 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1354 effects.emplace_back(MemoryEffects::Write::get(),
1356 effects.emplace_back(MemoryEffects::Read::get(),
1358 // TODO: should this be MemoryEffects::Allocate?
1360}
1361
// Visible effects: Read of acc::RuntimeCounters, then a Write and a further
// Read.
// NOTE(review): the resources of the Write and the further Read are on lines
// missing from this listing.
1362// NoCreateOp: RuntimeCounters read+write.
1363void acc::NoCreateOp::getEffects(
1365 &effects) {
1366 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1367 effects.emplace_back(MemoryEffects::Write::get(),
1369 effects.emplace_back(MemoryEffects::Read::get(),
1371}
1372
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here. The var operand is only marked as Read (see TODO below).
1373// AttachOp: RuntimeCounters read+write, var read.
1374void acc::AttachOp::getEffects(
1376 &effects) {
1377 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1378 effects.emplace_back(MemoryEffects::Write::get(),
1380 effects.emplace_back(MemoryEffects::Read::get(),
1382 // TODO: should we also add MemoryEffects::Write?
1383 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1384}
1385
// Visible effects: a Read of acc::RuntimeCounters, plus a second Read.
// NOTE(review): the resource of the second Read is on a line missing from
// this listing.
1386// GetDevicePtrOp: RuntimeCounters read.
1387void acc::GetDevicePtrOp::getEffects(
1389 &effects) {
1390 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1391 effects.emplace_back(MemoryEffects::Read::get(),
1393}
1394
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so not every effect it registers
// is visible here. No RuntimeCounters effect appears in the visible lines.
1395// UpdateDeviceOp: var read, accVar result write.
1396void acc::UpdateDeviceOp::getEffects(
1398 &effects) {
1399 effects.emplace_back(MemoryEffects::Read::get(),
1401 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1403}
1404
// Visible effects: a Read of acc::RuntimeCounters, plus a second Read.
// NOTE(review): the resource of the second Read is on a line missing from
// this listing.
1405// UseDeviceOp: RuntimeCounters read.
1406void acc::UseDeviceOp::getEffects(
1408 &effects) {
1409 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1410 effects.emplace_back(MemoryEffects::Read::get(),
1412}
1413
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the first Read below are not visible here.
1414// DeclareDeviceResidentOp: RuntimeCounters write, var read.
1415void acc::DeclareDeviceResidentOp::getEffects(
1417 &effects) {
1418 effects.emplace_back(MemoryEffects::Write::get(),
1420 effects.emplace_back(MemoryEffects::Read::get(),
1422 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1423}
1424
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the first Read below are not visible here.
1425// DeclareLinkOp: RuntimeCounters write, var read.
1426void acc::DeclareLinkOp::getEffects(
1428 &effects) {
1429 effects.emplace_back(MemoryEffects::Write::get(),
1431 effects.emplace_back(MemoryEffects::Read::get(),
1433 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1434}
1435
// The body is empty, so CacheOp reports no memory effects at all.
// NOTE(review): the parameter line (1438) is missing from this listing.
1436// CacheOp: NoMemoryEffect
1437void acc::CacheOp::getEffects(
1439 &effects) {}
1440
// Data flows device -> host: the accVar operand is read and the var operand
// is written.
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the second Read are not visible here.
1441// CopyoutOp: RuntimeCounters read+write, accVar read, var write.
1442void acc::CopyoutOp::getEffects(
1444 &effects) {
1445 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1446 effects.emplace_back(MemoryEffects::Write::get(),
1448 effects.emplace_back(MemoryEffects::Read::get(),
1450 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1451 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1452}
1453
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the second Read are not visible here.
1454// DeleteOp: RuntimeCounters read+write, accVar read.
1455void acc::DeleteOp::getEffects(
1457 &effects) {
1458 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1459 effects.emplace_back(MemoryEffects::Write::get(),
1461 effects.emplace_back(MemoryEffects::Read::get(),
1463 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1464}
1465
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the second Read are not visible here.
1466// DetachOp: RuntimeCounters read+write, accVar read.
1467void acc::DetachOp::getEffects(
1469 &effects) {
1470 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1471 effects.emplace_back(MemoryEffects::Write::get(),
1473 effects.emplace_back(MemoryEffects::Read::get(),
1475 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1476}
1477
// Data flows device -> host: the accVar operand is read and the var operand
// is written.
// NOTE(review): some source lines of this function are missing from this
// listing (see the gaps in line numbering), so the resources of the Write and
// the second Read are not visible here.
1478// UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
1479void acc::UpdateHostOp::getEffects(
1481 &effects) {
1482 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1483 effects.emplace_back(MemoryEffects::Write::get(),
1485 effects.emplace_back(MemoryEffects::Read::get(),
1487 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1488 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1489}
1490
// Adds `nRegions` regions to `state`, then parses each in order with no
// entry-block arguments. Parsing stops at, and reports, the first region
// that fails to parse.
// The StructureOp template parameter is not referenced in the visible lines.
// NOTE(review): line 1495 is missing from this listing; presumably it
// declares `regions` (e.g. a SmallVector<Region *>) -- TODO confirm.
1491template <typename StructureOp>
1492static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1493 unsigned nRegions = 1) {
1494
1496 for (unsigned i = 0; i < nRegions; ++i)
1497 regions.push_back(state.addRegion());
1498
1499 for (Region *region : regions)
1500 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1501 return failure();
1502
1503 return success();
1504}
1505
// Returns true if `op` is any of the ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS.
// NOTE(review): the signature line (1506) is missing from this listing.
1507 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1508}
1509
1510namespace {
1511/// Pattern to remove operation without region that have constant false `ifCond`
1512/// and remove the condition from the operation if the `ifCond` is a true
1513/// constant.
1514template <typename OpTy>
1515struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1516 using OpRewritePattern<OpTy>::OpRewritePattern;
1517
1518 LogicalResult matchAndRewrite(OpTy op,
1519 PatternRewriter &rewriter) const override {
1520 // Early return if there is no condition.
1521 Value ifCond = op.getIfCond();
1522 if (!ifCond)
1523 return failure();
1524
1525 IntegerAttr constAttr;
1526 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1527 return failure();
1528 if (constAttr.getInt())
1529 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1530 else
1531 rewriter.eraseOp(op);
1532
1533 return success();
1534 }
1535};
1536
1537/// Replaces the given op with the contents of the given single-block region,
1538/// using the operands of the block terminator to replace operation results.
1539static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1540 Region &region, ValueRange blockArgs = {}) {
1541 assert(region.hasOneBlock() && "expected single-block region");
1542 Block *block = &region.front();
1543 Operation *terminator = block->getTerminator();
1544 ValueRange results = terminator->getOperands();
1545 rewriter.inlineBlockBefore(block, op, blockArgs);
1546 rewriter.replaceOp(op, results);
1547 rewriter.eraseOp(terminator);
1548}
1549
1550/// Pattern to remove operation with region that have constant false `ifCond`
1551/// and remove the condition from the operation if the `ifCond` is constant
1552/// true.
1553template <typename OpTy>
1554struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1555 using OpRewritePattern<OpTy>::OpRewritePattern;
1556
1557 LogicalResult matchAndRewrite(OpTy op,
1558 PatternRewriter &rewriter) const override {
1559 // Early return if there is no condition.
1560 Value ifCond = op.getIfCond();
1561 if (!ifCond)
1562 return failure();
1563
1564 IntegerAttr constAttr;
1565 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1566 return failure();
1567 if (constAttr.getInt())
1568 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1569 else
1570 replaceOpWithRegion(rewriter, op, op.getRegion());
1571
1572 return success();
1573 }
1574};
1575
1576/// Remove empty acc.kernel_environment operations. If the operation has wait
1577/// operands, create a acc.wait operation to preserve synchronization.
1578struct RemoveEmptyKernelEnvironment
1579 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1580 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1581
1582 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1583 PatternRewriter &rewriter) const override {
1584 assert(op->getNumRegions() == 1 && "expected op to have one region");
1585
1586 Block &block = op.getRegion().front();
1587 if (!block.empty())
1588 return failure();
1589
1590 // Conservatively disable canonicalization of empty acc.kernel_environment
1591 // operations if the wait operands in the kernel_environment cannot be fully
1592 // represented by acc.wait operation.
1593
1594 // Disable canonicalization if device type is not the default
1595 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1596 for (auto attr : deviceTypeAttr) {
1597 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1598 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1599 return failure();
1600 }
1601 }
1602 }
1603
1604 // Disable canonicalization if any wait segment has a devnum
1605 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1606 for (auto attr : hasDevnumAttr) {
1607 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1608 if (boolAttr.getValue())
1609 return failure();
1610 }
1611 }
1612 }
1613
1614 // Disable canonicalization if there are multiple wait segments
1615 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1616 if (segmentsAttr.size() > 1)
1617 return failure();
1618 }
1619
1620 // Remove empty kernel environment.
1621 // Preserve synchronization by creating acc.wait operation if needed.
1622 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1623 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1624 /*asyncOperand=*/Value(),
1625 /*waitDevnum=*/Value(),
1626 /*async=*/nullptr,
1627 /*ifCond=*/Value());
1628 else
1629 rewriter.eraseOp(op);
1630
1631 return success();
1632 }
1633};
1634
1635//===----------------------------------------------------------------------===//
1636// Recipe Region Helpers
1637//===----------------------------------------------------------------------===//
1638
1639/// Create and populate an init region for privatization recipes.
1640/// Returns success if the region is populated, failure otherwise.
1641/// Sets needsFree to indicate if the allocated memory requires deallocation.
1642static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1643 Region &initRegion, Type varType,
1644 StringRef varName, ValueRange bounds,
1645 bool &needsFree) {
1646 // Create init block with arguments: original value + bounds
1647 SmallVector<Type> argTypes{varType};
1648 SmallVector<Location> argLocs{loc};
1649 for (Value bound : bounds) {
1650 argTypes.push_back(bound.getType());
1651 argLocs.push_back(loc);
1652 }
1653
1654 Block *initBlock = builder.createBlock(&initRegion);
1655 initBlock->addArguments(argTypes, argLocs);
1656 builder.setInsertionPointToStart(initBlock);
1657
1658 Value privatizedValue;
1659
1660 // Get the block argument that represents the original variable
1661 Value blockArgVar = initBlock->getArgument(0);
1662
1663 // Generate init region body based on variable type
1664 if (isa<MappableType>(varType)) {
1665 auto mappableTy = cast<MappableType>(varType);
1666 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1667 privatizedValue = mappableTy.generatePrivateInit(
1668 builder, loc, typedVar, varName, bounds, {}, needsFree);
1669 if (!privatizedValue)
1670 return failure();
1671 } else {
1672 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1673 auto pointerLikeTy = cast<PointerLikeType>(varType);
1674 // Use PointerLikeType's allocation API with the block argument
1675 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1676 blockArgVar, needsFree);
1677 if (!privatizedValue)
1678 return failure();
1679 }
1680
1681 // Add yield operation to init block
1682 acc::YieldOp::create(builder, loc, privatizedValue);
1683
1684 return success();
1685}
1686
1687/// Create and populate a copy region for firstprivate recipes.
1688/// Returns success if the region is populated, failure otherwise.
1689/// TODO: Handle MappableType - it does not yet have a copy API.
1690static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1691 Region &copyRegion, Type varType,
1692 ValueRange bounds) {
1693 // Create copy block with arguments: original value + privatized value +
1694 // bounds
1695 SmallVector<Type> copyArgTypes{varType, varType};
1696 SmallVector<Location> copyArgLocs{loc, loc};
1697 for (Value bound : bounds) {
1698 copyArgTypes.push_back(bound.getType());
1699 copyArgLocs.push_back(loc);
1700 }
1701
1702 Block *copyBlock = builder.createBlock(&copyRegion);
1703 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1704 builder.setInsertionPointToStart(copyBlock);
1705
1706 bool isMappable = isa<MappableType>(varType);
1707 bool isPointerLike = isa<PointerLikeType>(varType);
1708 // TODO: Handle MappableType - it does not yet have a copy API.
1709 // Otherwise, for now just fallback to pointer-like behavior.
1710 if (isMappable && !isPointerLike)
1711 return failure();
1712
1713 // Generate copy region body based on variable type
1714 if (isPointerLike) {
1715 auto pointerLikeTy = cast<PointerLikeType>(varType);
1716 Value originalArg = copyBlock->getArgument(0);
1717 Value privatizedArg = copyBlock->getArgument(1);
1718
1719 // Generate copy operation using PointerLikeType interface
1720 if (!pointerLikeTy.genCopy(
1721 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1722 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1723 return failure();
1724 }
1725
1726 // Add terminator to copy block
1727 acc::TerminatorOp::create(builder, loc);
1728
1729 return success();
1730}
1731
1732/// Create and populate a destroy region for privatization recipes.
1733/// Returns success if the region is populated, failure otherwise.
1734static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1735 Region &destroyRegion, Type varType,
1736 Value allocRes, ValueRange bounds) {
1737 // Create destroy block with arguments: original value + privatized value +
1738 // bounds
1739 SmallVector<Type> destroyArgTypes{varType, varType};
1740 SmallVector<Location> destroyArgLocs{loc, loc};
1741 for (Value bound : bounds) {
1742 destroyArgTypes.push_back(bound.getType());
1743 destroyArgLocs.push_back(loc);
1744 }
1745
1746 Block *destroyBlock = builder.createBlock(&destroyRegion);
1747 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1748 builder.setInsertionPointToStart(destroyBlock);
1749
1750 auto varToFree =
1751 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1752 if (isa<MappableType>(varType)) {
1753 auto mappableTy = cast<MappableType>(varType);
1754 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1755 return failure();
1756 } else {
1757 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1758 auto pointerLikeTy = cast<PointerLikeType>(varType);
1759 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1760 return failure();
1761 }
1762
1763 acc::TerminatorOp::create(builder, loc);
1764 return success();
1765}
1766
1767} // namespace
1768
1769//===----------------------------------------------------------------------===//
1770// PrivateRecipeOp
1771//===----------------------------------------------------------------------===//
1772
1774 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1775 Type type, bool verifyYield, bool optional = false) {
// Verifies a recipe region whose first block takes the recipe `type` as its
// first argument. When `verifyYield` is set, every acc.yield directly in the
// region (region.getOps) must yield exactly one value of `type`. Only the
// first block's argument is inspected.
// NOTE(review): the signature's first line (1773) is missing from this
// listing.
1776 if (optional && region.empty())
1777 return success();
1778
1779 if (region.empty())
1780 return op->emitOpError() << "expects non-empty " << regionName << " region";
1781 Block &firstBlock = region.front();
1782 if (firstBlock.getNumArguments() < 1 ||
1783 firstBlock.getArgument(0).getType() != type)
1784 return op->emitOpError() << "expects " << regionName
1785 << " region first "
1786 "argument of the "
1787 << regionType << " type";
1788
1789 if (verifyYield) {
1790 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1791 if (yieldOp.getOperands().size() != 1 ||
1792 yieldOp.getOperands().getTypes()[0] != type)
1793 return op->emitOpError() << "expects " << regionName
1794 << " region to "
1795 "yield a value of the "
1796 << regionType << " type";
1797 }
1798 }
1799 return success();
1800}
1801
1802LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1803 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1804 "privatization", "init", getType(),
1805 /*verifyYield=*/false)))
1806 return failure();
1808 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1809 /*verifyYield=*/false, /*optional=*/true)))
1810 return failure();
1811 return success();
1812}
1813
1814std::optional<PrivateRecipeOp>
1815PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1816 StringRef recipeName, Type varType,
1817 StringRef varName, ValueRange bounds) {
1818 // First, validate that we can handle this variable type
1819 bool isMappable = isa<MappableType>(varType);
1820 bool isPointerLike = isa<PointerLikeType>(varType);
1821
1822 // Unsupported type
1823 if (!isMappable && !isPointerLike)
1824 return std::nullopt;
1825
1826 OpBuilder::InsertionGuard guard(builder);
1827
1828 // Create the recipe operation first so regions have proper parent context
1829 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1830
1831 // Populate the init region
1832 bool needsFree = false;
1833 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1834 varName, bounds, needsFree))) {
1835 recipe.erase();
1836 return std::nullopt;
1837 }
1838
1839 // Only create destroy region if the allocation needs deallocation
1840 if (needsFree) {
1841 // Extract the allocated value from the init block's yield operation
1842 auto yieldOp =
1843 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1844 Value allocRes = yieldOp.getOperand(0);
1845
1846 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1847 varType, allocRes, bounds))) {
1848 recipe.erase();
1849 return std::nullopt;
1850 }
1851 }
1852
1853 return recipe;
1854}
1855
1856std::optional<PrivateRecipeOp>
1857PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1858 StringRef recipeName,
1859 FirstprivateRecipeOp firstprivRecipe) {
1860 // Create the private.recipe op with the same type as the firstprivate.recipe.
1861 OpBuilder::InsertionGuard guard(builder);
1862 auto varType = firstprivRecipe.getType();
1863 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1864
1865 // Clone the init region
1866 IRMapping mapping;
1867 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1868
1869 // Clone destroy region if the firstprivate.recipe has one.
1870 if (!firstprivRecipe.getDestroyRegion().empty()) {
1871 IRMapping mapping;
1872 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1873 mapping);
1874 }
1875 return recipe;
1876}
1877
1878//===----------------------------------------------------------------------===//
1879// FirstprivateRecipeOp
1880//===----------------------------------------------------------------------===//
1881
1882LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1883 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1884 "privatization", "init", getType(),
1885 /*verifyYield=*/false)))
1886 return failure();
1887
1888 if (getCopyRegion().empty())
1889 return emitOpError() << "expects non-empty copy region";
1890
1891 Block &firstBlock = getCopyRegion().front();
1892 if (firstBlock.getNumArguments() < 2 ||
1893 firstBlock.getArgument(0).getType() != getType())
1894 return emitOpError() << "expects copy region with two arguments of the "
1895 "privatization type";
1896
1897 if (getDestroyRegion().empty())
1898 return success();
1899
1900 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1901 "privatization", "destroy",
1902 getType(), /*verifyYield=*/false)))
1903 return failure();
1904
1905 return success();
1906}
1907
1908std::optional<FirstprivateRecipeOp>
1909FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1910 StringRef recipeName, Type varType,
1911 StringRef varName, ValueRange bounds) {
1912 // First, validate that we can handle this variable type
1913 bool isMappable = isa<MappableType>(varType);
1914 bool isPointerLike = isa<PointerLikeType>(varType);
1915
1916 // Unsupported type
1917 if (!isMappable && !isPointerLike)
1918 return std::nullopt;
1919
1920 OpBuilder::InsertionGuard guard(builder);
1921
1922 // Create the recipe operation first so regions have proper parent context
1923 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1924
1925 // Populate the init region
1926 bool needsFree = false;
1927 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1928 varName, bounds, needsFree))) {
1929 recipe.erase();
1930 return std::nullopt;
1931 }
1932
1933 // Populate the copy region
1934 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1935 bounds))) {
1936 recipe.erase();
1937 return std::nullopt;
1938 }
1939
1940 // Only create destroy region if the allocation needs deallocation
1941 if (needsFree) {
1942 // Extract the allocated value from the init block's yield operation
1943 auto yieldOp =
1944 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1945 Value allocRes = yieldOp.getOperand(0);
1946
1947 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1948 varType, allocRes, bounds))) {
1949 recipe.erase();
1950 return std::nullopt;
1951 }
1952 }
1953
1954 return recipe;
1955}
1956
1957//===----------------------------------------------------------------------===//
1958// ReductionRecipeOp
1959//===----------------------------------------------------------------------===//
1960
1961LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1962 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1963 "init", getType(),
1964 /*verifyYield=*/false)))
1965 return failure();
1966
1967 if (getCombinerRegion().empty())
1968 return emitOpError() << "expects non-empty combiner region";
1969
1970 Block &reductionBlock = getCombinerRegion().front();
1971 if (reductionBlock.getNumArguments() < 2 ||
1972 reductionBlock.getArgument(0).getType() != getType() ||
1973 reductionBlock.getArgument(1).getType() != getType())
1974 return emitOpError() << "expects combiner region with the first two "
1975 << "arguments of the reduction type";
1976
1977 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1978 if (yieldOp.getOperands().size() != 1 ||
1979 yieldOp.getOperands().getTypes()[0] != getType())
1980 return emitOpError() << "expects combiner region to yield a value "
1981 "of the reduction type";
1982 }
1983
1984 return success();
1985}
1986
1987//===----------------------------------------------------------------------===//
1988// ParallelOp
1989//===----------------------------------------------------------------------===//
1990
1991/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1992template <typename Op>
1993static LogicalResult checkDataOperands(Op op,
1994 const mlir::ValueRange &operands) {
1995 for (mlir::Value operand : operands)
1996 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1997 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1998 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1999 operand.getDefiningOp()))
2000 return op.emitError(
2001 "expect data entry/exit operation or acc.getdeviceptr "
2002 "as defining op");
2003 return success();
2004}
2005
2006template <typename OpT, typename RecipeOpT>
2007static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
2008 const mlir::ValueRange &operands,
2009 llvm::StringRef operandName) {
2011 for (mlir::Value operand : operands) {
2012 if (!mlir::isa<OpT>(operand.getDefiningOp()))
2013 return accConstructOp->emitOpError()
2014 << "expected " << operandName << " as defining op";
2015 if (!set.insert(operand).second)
2016 return accConstructOp->emitOpError()
2017 << operandName << " operand appears more than once";
2018 }
2019 return success();
2020}
2021
2022unsigned ParallelOp::getNumDataOperands() {
2023 return getReductionOperands().size() + getPrivateOperands().size() +
2024 getFirstprivateOperands().size() + getDataClauseOperands().size();
2025}
2026
2027Value ParallelOp::getDataOperand(unsigned i) {
2028 unsigned numOptional = getAsyncOperands().size();
2029 numOptional += getNumGangs().size();
2030 numOptional += getNumWorkers().size();
2031 numOptional += getVectorLength().size();
2032 numOptional += getIfCond() ? 1 : 0;
2033 numOptional += getSelfCond() ? 1 : 0;
2034 return getOperand(getWaitOperands().size() + numOptional + i);
2035}
2036
2037template <typename Op>
2038static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
2039 ArrayAttr deviceTypes,
2040 llvm::StringRef keyword) {
2041 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
2042 return op.emitOpError() << keyword << " operands count must match "
2043 << keyword << " device_type count";
2044 return success();
2045}
2046
2047template <typename Op>
2049 Op op, OperandRange operands, DenseI32ArrayAttr segments,
2050 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
2051 std::size_t numOperandsInSegments = 0;
2052 std::size_t nbOfSegments = 0;
2053
2054 if (segments) {
2055 for (auto segCount : segments.asArrayRef()) {
2056 if (maxInSegment != 0 && segCount > maxInSegment)
2057 return op.emitOpError() << keyword << " expects a maximum of "
2058 << maxInSegment << " values per segment";
2059 numOperandsInSegments += segCount;
2060 ++nbOfSegments;
2061 }
2062 }
2063
2064 if ((numOperandsInSegments != operands.size()) ||
2065 (!deviceTypes && !operands.empty()))
2066 return op.emitOpError()
2067 << keyword << " operand count does not match count in segments";
2068 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
2069 return op.emitOpError()
2070 << keyword << " segment count does not match device_type count";
2071 return success();
2072}
2073
2074LogicalResult acc::ParallelOp::verify() {
2075 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2076 mlir::acc::PrivateRecipeOp>(
2077 *this, getPrivateOperands(), "private")))
2078 return failure();
2079 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2080 mlir::acc::FirstprivateRecipeOp>(
2081 *this, getFirstprivateOperands(), "firstprivate")))
2082 return failure();
2083 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2084 mlir::acc::ReductionRecipeOp>(
2085 *this, getReductionOperands(), "reduction")))
2086 return failure();
2087
2089 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2090 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2091 return failure();
2092
2094 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2095 getWaitOperandsDeviceTypeAttr(), "wait")))
2096 return failure();
2097
2098 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2099 getNumWorkersDeviceTypeAttr(),
2100 "num_workers")))
2101 return failure();
2102
2103 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2104 getVectorLengthDeviceTypeAttr(),
2105 "vector_length")))
2106 return failure();
2107
2109 getAsyncOperandsDeviceTypeAttr(),
2110 "async")))
2111 return failure();
2112
2114 return failure();
2115
2116 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
2117}
2118
2119static mlir::Value
2120getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
2122 mlir::acc::DeviceType deviceType) {
2123 if (!arrayAttr)
2124 return {};
2125 if (auto pos = findSegment(*arrayAttr, deviceType))
2126 return range[*pos];
2127 return {};
2128}
2129
2130bool acc::ParallelOp::hasAsyncOnly() {
2131 return hasAsyncOnly(mlir::acc::DeviceType::None);
2132}
2133
2134bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2135 return hasDeviceType(getAsyncOnly(), deviceType);
2136}
2137
2138mlir::Value acc::ParallelOp::getAsyncValue() {
2139 return getAsyncValue(mlir::acc::DeviceType::None);
2140}
2141
2142mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2144 getAsyncOperands(), deviceType);
2145}
2146
2147mlir::Value acc::ParallelOp::getNumWorkersValue() {
2148 return getNumWorkersValue(mlir::acc::DeviceType::None);
2149}
2150
2152acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2153 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2154 deviceType);
2155}
2156
2157mlir::Value acc::ParallelOp::getVectorLengthValue() {
2158 return getVectorLengthValue(mlir::acc::DeviceType::None);
2159}
2160
2162acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2163 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2164 getVectorLength(), deviceType);
2165}
2166
2167mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2168 return getNumGangsValues(mlir::acc::DeviceType::None);
2169}
2170
2172ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2173 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2174 getNumGangsSegments(), deviceType);
2175}
2176
2177bool acc::ParallelOp::hasWaitOnly() {
2178 return hasWaitOnly(mlir::acc::DeviceType::None);
2179}
2180
2181bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2182 return hasDeviceType(getWaitOnly(), deviceType);
2183}
2184
2185mlir::Operation::operand_range ParallelOp::getWaitValues() {
2186 return getWaitValues(mlir::acc::DeviceType::None);
2187}
2188
2190ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2192 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2193 getHasWaitDevnum(), deviceType);
2194}
2195
2196mlir::Value ParallelOp::getWaitDevnum() {
2197 return getWaitDevnum(mlir::acc::DeviceType::None);
2198}
2199
2200mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2201 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2202 getWaitOperandsSegments(), getHasWaitDevnum(),
2203 deviceType);
2204}
2205
/// Convenience builder that takes only operand lists and conditions.
///
/// Forwards to the full builder. Every device_type attribute, segment
/// attribute, keyword-only marker (asyncOnly, waitOnly), hasWaitDevnum, the
/// self attribute, the default attribute and the combined marker are passed
/// as null. `gangPrivateOperands` and `gangFirstPrivateOperands` fill the
/// private and firstprivate operand positions, respectively.
void ParallelOp::build(mlir::OpBuilder &odsBuilder,
                       mlir::OperationState &odsState,
                       mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
                       mlir::ValueRange vectorLength,
                       mlir::ValueRange asyncOperands,
                       mlir::ValueRange waitOperands, mlir::Value ifCond,
                       mlir::Value selfCond, mlir::ValueRange reductionOperands,
                       mlir::ValueRange gangPrivateOperands,
                       mlir::ValueRange gangFirstPrivateOperands,
                       mlir::ValueRange dataClauseOperands) {
  // Argument order must follow the full builder's signature exactly; the
  // null attributes mean "not specified".
  ParallelOp::build(
      odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
      /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
      /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
      /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
      /*numGangsDeviceType=*/nullptr, numWorkers,
      /*numWorkersDeviceType=*/nullptr, vectorLength,
      /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
      /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
      gangFirstPrivateOperands, dataClauseOperands,
      /*defaultAttr=*/nullptr, /*combined=*/nullptr);
}
2228
2229void acc::ParallelOp::addNumWorkersOperand(
2230 MLIRContext *context, mlir::Value newValue,
2231 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2232 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2233 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2234 getNumWorkersMutable()));
2235}
2236void acc::ParallelOp::addVectorLengthOperand(
2237 MLIRContext *context, mlir::Value newValue,
2238 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2239 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2240 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2241 getVectorLengthMutable()));
2242}
2243
2244void acc::ParallelOp::addAsyncOnly(
2245 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2246 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2247 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2248}
2249
2250void acc::ParallelOp::addAsyncOperand(
2251 MLIRContext *context, mlir::Value newValue,
2252 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2253 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2254 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2255 getAsyncOperandsMutable()));
2256}
2257
2258void acc::ParallelOp::addNumGangsOperands(
2259 MLIRContext *context, mlir::ValueRange newValues,
2260 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2262 if (getNumGangsSegments())
2263 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2264
2265 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2266 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2267 getNumGangsMutable(), segments));
2268
2269 setNumGangsSegments(segments);
2270}
2271void acc::ParallelOp::addWaitOnly(
2272 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2273 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2274 effectiveDeviceTypes));
2275}
2276void acc::ParallelOp::addWaitOperands(
2277 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2278 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2279
2281 if (getWaitOperandsSegments())
2282 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2283
2284 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2285 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2286 getWaitOperandsMutable(), segments));
2287 setWaitOperandsSegments(segments);
2288
2290 if (getHasWaitDevnumAttr())
2291 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2292 hasDevnums.insert(
2293 hasDevnums.end(),
2294 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2295 mlir::BoolAttr::get(context, hasDevnum));
2296 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2297}
2298
2299void acc::ParallelOp::addPrivatization(MLIRContext *context,
2300 mlir::acc::PrivateOp op,
2301 mlir::acc::PrivateRecipeOp recipe) {
2302 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2303 getPrivateOperandsMutable().append(op.getResult());
2304}
2305
2306void acc::ParallelOp::addFirstPrivatization(
2307 MLIRContext *context, mlir::acc::FirstprivateOp op,
2308 mlir::acc::FirstprivateRecipeOp recipe) {
2309 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2310 getFirstprivateOperandsMutable().append(op.getResult());
2311}
2312
2313void acc::ParallelOp::addReduction(MLIRContext *context,
2314 mlir::acc::ReductionOp op,
2315 mlir::acc::ReductionRecipeOp recipe) {
2316 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2317 getReductionOperandsMutable().append(op.getResult());
2318}
2319
2320static ParseResult parseNumGangs(
2321 mlir::OpAsmParser &parser,
2323 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2324 mlir::DenseI32ArrayAttr &segments) {
2327
2328 do {
2329 if (failed(parser.parseLBrace()))
2330 return failure();
2331
2332 int32_t crtOperandsSize = operands.size();
2333 if (failed(parser.parseCommaSeparatedList(
2335 if (parser.parseOperand(operands.emplace_back()) ||
2336 parser.parseColonType(types.emplace_back()))
2337 return failure();
2338 return success();
2339 })))
2340 return failure();
2341 seg.push_back(operands.size() - crtOperandsSize);
2342
2343 if (failed(parser.parseRBrace()))
2344 return failure();
2345
2346 if (succeeded(parser.parseOptionalLSquare())) {
2347 if (parser.parseAttribute(attributes.emplace_back()) ||
2348 parser.parseRSquare())
2349 return failure();
2350 } else {
2351 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2352 parser.getContext(), mlir::acc::DeviceType::None));
2353 }
2354 } while (succeeded(parser.parseOptionalComma()));
2355
2356 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2357 attributes.end());
2358 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2359 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2360
2361 return success();
2362}
2363
2365 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2366 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2367 p << " [" << attr << "]";
2368}
2369
2371 mlir::OperandRange operands, mlir::TypeRange types,
2372 std::optional<mlir::ArrayAttr> deviceTypes,
2373 std::optional<mlir::DenseI32ArrayAttr> segments) {
2374 unsigned opIdx = 0;
2375 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2376 p << "{";
2377 llvm::interleaveComma(
2378 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2379 p << operands[opIdx] << " : " << operands[opIdx].getType();
2380 ++opIdx;
2381 });
2382 p << "}";
2383 printSingleDeviceType(p, it.value());
2384 });
2385}
2386
2388 mlir::OpAsmParser &parser,
2390 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2391 mlir::DenseI32ArrayAttr &segments) {
2394
2395 do {
2396 if (failed(parser.parseLBrace()))
2397 return failure();
2398
2399 int32_t crtOperandsSize = operands.size();
2400
2401 if (failed(parser.parseCommaSeparatedList(
2403 if (parser.parseOperand(operands.emplace_back()) ||
2404 parser.parseColonType(types.emplace_back()))
2405 return failure();
2406 return success();
2407 })))
2408 return failure();
2409
2410 seg.push_back(operands.size() - crtOperandsSize);
2411
2412 if (failed(parser.parseRBrace()))
2413 return failure();
2414
2415 if (succeeded(parser.parseOptionalLSquare())) {
2416 if (parser.parseAttribute(attributes.emplace_back()) ||
2417 parser.parseRSquare())
2418 return failure();
2419 } else {
2420 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2421 parser.getContext(), mlir::acc::DeviceType::None));
2422 }
2423 } while (succeeded(parser.parseOptionalComma()));
2424
2425 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2426 attributes.end());
2427 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2428 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2429
2430 return success();
2431}
2432
2435 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2436 std::optional<mlir::DenseI32ArrayAttr> segments) {
2437 unsigned opIdx = 0;
2438 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2439 p << "{";
2440 llvm::interleaveComma(
2441 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2442 p << operands[opIdx] << " : " << operands[opIdx].getType();
2443 ++opIdx;
2444 });
2445 p << "}";
2446 printSingleDeviceType(p, it.value());
2447 });
2448}
2449
2450static ParseResult parseWaitClause(
2451 mlir::OpAsmParser &parser,
2453 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2454 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2455 mlir::ArrayAttr &keywordOnly) {
2456 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2458
2459 bool needCommaBeforeOperands = false;
2460
2461 // Keyword only
2462 if (failed(parser.parseOptionalLParen())) {
2463 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2464 parser.getContext(), mlir::acc::DeviceType::None));
2465 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2466 return success();
2467 }
2468
2469 // Parse keyword only attributes
2470 if (succeeded(parser.parseOptionalLSquare())) {
2471 if (failed(parser.parseCommaSeparatedList([&]() {
2472 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2473 return failure();
2474 return success();
2475 })))
2476 return failure();
2477 if (parser.parseRSquare())
2478 return failure();
2479 needCommaBeforeOperands = true;
2480 }
2481
2482 if (needCommaBeforeOperands && failed(parser.parseComma()))
2483 return failure();
2484
2485 do {
2486 if (failed(parser.parseLBrace()))
2487 return failure();
2488
2489 int32_t crtOperandsSize = operands.size();
2490
2491 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2492 if (failed(parser.parseColon()))
2493 return failure();
2494 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2495 } else {
2496 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2497 }
2498
2499 if (failed(parser.parseCommaSeparatedList(
2501 if (parser.parseOperand(operands.emplace_back()) ||
2502 parser.parseColonType(types.emplace_back()))
2503 return failure();
2504 return success();
2505 })))
2506 return failure();
2507
2508 seg.push_back(operands.size() - crtOperandsSize);
2509
2510 if (failed(parser.parseRBrace()))
2511 return failure();
2512
2513 if (succeeded(parser.parseOptionalLSquare())) {
2514 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2515 parser.parseRSquare())
2516 return failure();
2517 } else {
2518 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2519 parser.getContext(), mlir::acc::DeviceType::None));
2520 }
2521 } while (succeeded(parser.parseOptionalComma()));
2522
2523 if (failed(parser.parseRParen()))
2524 return failure();
2525
2526 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2527 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2528 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2529 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2530
2531 return success();
2532}
2533
2534static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2535 if (!hasDeviceTypeValues(attrs))
2536 return false;
2537 if (attrs->size() != 1)
2538 return false;
2539 if (auto deviceTypeAttr =
2540 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2541 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2542 return false;
2543}
2544
2546 mlir::OperandRange operands, mlir::TypeRange types,
2547 std::optional<mlir::ArrayAttr> deviceTypes,
2548 std::optional<mlir::DenseI32ArrayAttr> segments,
2549 std::optional<mlir::ArrayAttr> hasDevNum,
2550 std::optional<mlir::ArrayAttr> keywordOnly) {
2551
2552 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2553 return;
2554
2555 p << "(";
2556
2557 printDeviceTypes(p, keywordOnly);
2558 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2559 p << ", ";
2560
2561 if (hasDeviceTypeValues(deviceTypes)) {
2562 unsigned opIdx = 0;
2563 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2564 p << "{";
2565 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2566 if (boolAttr && boolAttr.getValue())
2567 p << "devnum: ";
2568 llvm::interleaveComma(
2569 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2570 p << operands[opIdx] << " : " << operands[opIdx].getType();
2571 ++opIdx;
2572 });
2573 p << "}";
2574 printSingleDeviceType(p, it.value());
2575 });
2576 }
2577
2578 p << ")";
2579}
2580
2581static ParseResult parseDeviceTypeOperands(
2582 mlir::OpAsmParser &parser,
2584 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2586 if (failed(parser.parseCommaSeparatedList([&]() {
2587 if (parser.parseOperand(operands.emplace_back()) ||
2588 parser.parseColonType(types.emplace_back()))
2589 return failure();
2590 if (succeeded(parser.parseOptionalLSquare())) {
2591 if (parser.parseAttribute(attributes.emplace_back()) ||
2592 parser.parseRSquare())
2593 return failure();
2594 } else {
2595 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2596 parser.getContext(), mlir::acc::DeviceType::None));
2597 }
2598 return success();
2599 })))
2600 return failure();
2601 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2602 attributes.end());
2603 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2604 return success();
2605}
2606
2607static void
2609 mlir::OperandRange operands, mlir::TypeRange types,
2610 std::optional<mlir::ArrayAttr> deviceTypes) {
2611 if (!hasDeviceTypeValues(deviceTypes))
2612 return;
2613 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2614 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2615 printSingleDeviceType(p, std::get<0>(it));
2616 });
2617}
2618
2620 mlir::OpAsmParser &parser,
2622 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2623 mlir::ArrayAttr &keywordOnlyDeviceType) {
2624
2625 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2626 bool needCommaBeforeOperands = false;
2627
2628 if (failed(parser.parseOptionalLParen())) {
2629 // Keyword only
2630 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2631 parser.getContext(), mlir::acc::DeviceType::None));
2632 keywordOnlyDeviceType =
2633 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2634 return success();
2635 }
2636
2637 // Parse keyword only attributes
2638 if (succeeded(parser.parseOptionalLSquare())) {
2639 // Parse keyword only attributes
2640 if (failed(parser.parseCommaSeparatedList([&]() {
2641 if (parser.parseAttribute(
2642 keywordOnlyDeviceTypeAttributes.emplace_back()))
2643 return failure();
2644 return success();
2645 })))
2646 return failure();
2647 if (parser.parseRSquare())
2648 return failure();
2649 needCommaBeforeOperands = true;
2650 }
2651
2652 if (needCommaBeforeOperands && failed(parser.parseComma()))
2653 return failure();
2654
2656 if (failed(parser.parseCommaSeparatedList([&]() {
2657 if (parser.parseOperand(operands.emplace_back()) ||
2658 parser.parseColonType(types.emplace_back()))
2659 return failure();
2660 if (succeeded(parser.parseOptionalLSquare())) {
2661 if (parser.parseAttribute(attributes.emplace_back()) ||
2662 parser.parseRSquare())
2663 return failure();
2664 } else {
2665 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2666 parser.getContext(), mlir::acc::DeviceType::None));
2667 }
2668 return success();
2669 })))
2670 return failure();
2671
2672 if (failed(parser.parseRParen()))
2673 return failure();
2674
2675 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2676 attributes.end());
2677 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2678 return success();
2679}
2680
2683 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2684 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2685
2686 if (operands.begin() == operands.end() &&
2687 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2688 return;
2689 }
2690
2691 p << "(";
2692 printDeviceTypes(p, keywordOnlyDeviceTypes);
2693 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2694 hasDeviceTypeValues(deviceTypes))
2695 p << ", ";
2696 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2697 p << ")";
2698}
2699
2701 mlir::OpAsmParser &parser,
2702 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2703 mlir::Type &operandType, mlir::UnitAttr &attr) {
2704 // Keyword only
2705 if (failed(parser.parseOptionalLParen())) {
2706 attr = mlir::UnitAttr::get(parser.getContext());
2707 return success();
2708 }
2709
2711 if (failed(parser.parseOperand(op)))
2712 return failure();
2713 operand = op;
2714 if (failed(parser.parseColon()))
2715 return failure();
2716 if (failed(parser.parseType(operandType)))
2717 return failure();
2718 if (failed(parser.parseRParen()))
2719 return failure();
2720
2721 return success();
2722}
2723
2725 mlir::Operation *op,
2726 std::optional<mlir::Value> operand,
2727 mlir::Type operandType,
2728 mlir::UnitAttr attr) {
2729 if (attr)
2730 return;
2731
2732 p << "(";
2733 p.printOperand(*operand);
2734 p << " : ";
2735 p.printType(operandType);
2736 p << ")";
2737}
2738
2740 mlir::OpAsmParser &parser,
2742 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2743 // Keyword only
2744 if (failed(parser.parseOptionalLParen())) {
2745 attr = mlir::UnitAttr::get(parser.getContext());
2746 return success();
2747 }
2748
2749 if (failed(parser.parseCommaSeparatedList([&]() {
2750 if (parser.parseOperand(operands.emplace_back()))
2751 return failure();
2752 return success();
2753 })))
2754 return failure();
2755 if (failed(parser.parseColon()))
2756 return failure();
2757 if (failed(parser.parseCommaSeparatedList([&]() {
2758 if (parser.parseType(types.emplace_back()))
2759 return failure();
2760 return success();
2761 })))
2762 return failure();
2763 if (failed(parser.parseRParen()))
2764 return failure();
2765
2766 return success();
2767}
2768
2770 mlir::Operation *op,
2771 mlir::OperandRange operands,
2772 mlir::TypeRange types,
2773 mlir::UnitAttr attr) {
2774 if (attr)
2775 return;
2776
2777 p << "(";
2778 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2779 p << " : ";
2780 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2781 p << ")";
2782}
2783
2784static ParseResult
2786 mlir::acc::CombinedConstructsTypeAttr &attr) {
2787 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2788 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2789 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2790 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2791 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2792 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2793 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2794 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2795 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2796 } else {
2797 parser.emitError(parser.getCurrentLocation(),
2798 "expected compute construct name");
2799 return failure();
2800 }
2801 return success();
2802}
2803
2804static void
2806 mlir::acc::CombinedConstructsTypeAttr attr) {
2807 if (attr) {
2808 switch (attr.getValue()) {
2809 case mlir::acc::CombinedConstructsType::KernelsLoop:
2810 p << "kernels";
2811 break;
2812 case mlir::acc::CombinedConstructsType::ParallelLoop:
2813 p << "parallel";
2814 break;
2815 case mlir::acc::CombinedConstructsType::SerialLoop:
2816 p << "serial";
2817 break;
2818 };
2819 }
2820}
2821
2822//===----------------------------------------------------------------------===//
2823// SerialOp
2824//===----------------------------------------------------------------------===//
2825
2826unsigned SerialOp::getNumDataOperands() {
2827 return getReductionOperands().size() + getPrivateOperands().size() +
2828 getFirstprivateOperands().size() + getDataClauseOperands().size();
2829}
2830
2831Value SerialOp::getDataOperand(unsigned i) {
2832 unsigned numOptional = getAsyncOperands().size();
2833 numOptional += getIfCond() ? 1 : 0;
2834 numOptional += getSelfCond() ? 1 : 0;
2835 return getOperand(getWaitOperands().size() + numOptional + i);
2836}
2837
2838bool acc::SerialOp::hasAsyncOnly() {
2839 return hasAsyncOnly(mlir::acc::DeviceType::None);
2840}
2841
2842bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2843 return hasDeviceType(getAsyncOnly(), deviceType);
2844}
2845
2846mlir::Value acc::SerialOp::getAsyncValue() {
2847 return getAsyncValue(mlir::acc::DeviceType::None);
2848}
2849
2850mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2852 getAsyncOperands(), deviceType);
2853}
2854
2855bool acc::SerialOp::hasWaitOnly() {
2856 return hasWaitOnly(mlir::acc::DeviceType::None);
2857}
2858
2859bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2860 return hasDeviceType(getWaitOnly(), deviceType);
2861}
2862
2863mlir::Operation::operand_range SerialOp::getWaitValues() {
2864 return getWaitValues(mlir::acc::DeviceType::None);
2865}
2866
2868SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2870 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2871 getHasWaitDevnum(), deviceType);
2872}
2873
2874mlir::Value SerialOp::getWaitDevnum() {
2875 return getWaitDevnum(mlir::acc::DeviceType::None);
2876}
2877
2878mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2879 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2880 getWaitOperandsSegments(), getHasWaitDevnum(),
2881 deviceType);
2882}
2883
2884LogicalResult acc::SerialOp::verify() {
2885 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2886 mlir::acc::PrivateRecipeOp>(
2887 *this, getPrivateOperands(), "private")))
2888 return failure();
2889 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2890 mlir::acc::FirstprivateRecipeOp>(
2891 *this, getFirstprivateOperands(), "firstprivate")))
2892 return failure();
2893 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2894 mlir::acc::ReductionRecipeOp>(
2895 *this, getReductionOperands(), "reduction")))
2896 return failure();
2897
2899 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2900 getWaitOperandsDeviceTypeAttr(), "wait")))
2901 return failure();
2902
2904 getAsyncOperandsDeviceTypeAttr(),
2905 "async")))
2906 return failure();
2907
2909 return failure();
2910
2911 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2912}
2913
2914void acc::SerialOp::addAsyncOnly(
2915 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2916 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2917 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2918}
2919
2920void acc::SerialOp::addAsyncOperand(
2921 MLIRContext *context, mlir::Value newValue,
2922 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2923 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2924 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2925 getAsyncOperandsMutable()));
2926}
2927
2928void acc::SerialOp::addWaitOnly(
2929 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2930 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2931 effectiveDeviceTypes));
2932}
2933void acc::SerialOp::addWaitOperands(
2934 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2935 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2936
2938 if (getWaitOperandsSegments())
2939 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2940
2941 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2942 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2943 getWaitOperandsMutable(), segments));
2944 setWaitOperandsSegments(segments);
2945
2947 if (getHasWaitDevnumAttr())
2948 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2949 hasDevnums.insert(
2950 hasDevnums.end(),
2951 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2952 mlir::BoolAttr::get(context, hasDevnum));
2953 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2954}
2955
2956void acc::SerialOp::addPrivatization(MLIRContext *context,
2957 mlir::acc::PrivateOp op,
2958 mlir::acc::PrivateRecipeOp recipe) {
2959 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2960 getPrivateOperandsMutable().append(op.getResult());
2961}
2962
2963void acc::SerialOp::addFirstPrivatization(
2964 MLIRContext *context, mlir::acc::FirstprivateOp op,
2965 mlir::acc::FirstprivateRecipeOp recipe) {
2966 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2967 getFirstprivateOperandsMutable().append(op.getResult());
2968}
2969
2970void acc::SerialOp::addReduction(MLIRContext *context,
2971 mlir::acc::ReductionOp op,
2972 mlir::acc::ReductionRecipeOp recipe) {
2973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2974 getReductionOperandsMutable().append(op.getResult());
2975}
2976
2977//===----------------------------------------------------------------------===//
2978// KernelsOp
2979//===----------------------------------------------------------------------===//
2980
2981unsigned KernelsOp::getNumDataOperands() {
2982 return getDataClauseOperands().size();
2983}
2984
2985Value KernelsOp::getDataOperand(unsigned i) {
2986 unsigned numOptional = getAsyncOperands().size();
2987 numOptional += getWaitOperands().size();
2988 numOptional += getNumGangs().size();
2989 numOptional += getNumWorkers().size();
2990 numOptional += getVectorLength().size();
2991 numOptional += getIfCond() ? 1 : 0;
2992 numOptional += getSelfCond() ? 1 : 0;
2993 return getOperand(numOptional + i);
2994}
2995
2996bool acc::KernelsOp::hasAsyncOnly() {
2997 return hasAsyncOnly(mlir::acc::DeviceType::None);
2998}
2999
3000bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3001 return hasDeviceType(getAsyncOnly(), deviceType);
3002}
3003
3004mlir::Value acc::KernelsOp::getAsyncValue() {
3005 return getAsyncValue(mlir::acc::DeviceType::None);
3006}
3007
3008mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3010 getAsyncOperands(), deviceType);
3011}
3012
3013mlir::Value acc::KernelsOp::getNumWorkersValue() {
3014 return getNumWorkersValue(mlir::acc::DeviceType::None);
3015}
3016
3018acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
3019 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
3020 deviceType);
3021}
3022
3023mlir::Value acc::KernelsOp::getVectorLengthValue() {
3024 return getVectorLengthValue(mlir::acc::DeviceType::None);
3025}
3026
3028acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
3029 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
3030 getVectorLength(), deviceType);
3031}
3032
3033mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
3034 return getNumGangsValues(mlir::acc::DeviceType::None);
3035}
3036
3038KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
3039 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
3040 getNumGangsSegments(), deviceType);
3041}
3042
3043bool acc::KernelsOp::hasWaitOnly() {
3044 return hasWaitOnly(mlir::acc::DeviceType::None);
3045}
3046
3047bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3048 return hasDeviceType(getWaitOnly(), deviceType);
3049}
3050
3051mlir::Operation::operand_range KernelsOp::getWaitValues() {
3052 return getWaitValues(mlir::acc::DeviceType::None);
3053}
3054
3056KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3058 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3059 getHasWaitDevnum(), deviceType);
3060}
3061
3062mlir::Value KernelsOp::getWaitDevnum() {
3063 return getWaitDevnum(mlir::acc::DeviceType::None);
3064}
3065
3066mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3067 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3068 getWaitOperandsSegments(), getHasWaitDevnum(),
3069 deviceType);
3070}
3071
3072LogicalResult acc::KernelsOp::verify() {
3074 *this, getNumGangs(), getNumGangsSegmentsAttr(),
3075 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
3076 return failure();
3077
3079 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
3080 getWaitOperandsDeviceTypeAttr(), "wait")))
3081 return failure();
3082
3083 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
3084 getNumWorkersDeviceTypeAttr(),
3085 "num_workers")))
3086 return failure();
3087
3088 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
3089 getVectorLengthDeviceTypeAttr(),
3090 "vector_length")))
3091 return failure();
3092
3094 getAsyncOperandsDeviceTypeAttr(),
3095 "async")))
3096 return failure();
3097
3099 return failure();
3100
3101 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
3102}
3103
3104void acc::KernelsOp::addPrivatization(MLIRContext *context,
3105 mlir::acc::PrivateOp op,
3106 mlir::acc::PrivateRecipeOp recipe) {
3107 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3108 getPrivateOperandsMutable().append(op.getResult());
3109}
3110
3111void acc::KernelsOp::addFirstPrivatization(
3112 MLIRContext *context, mlir::acc::FirstprivateOp op,
3113 mlir::acc::FirstprivateRecipeOp recipe) {
3114 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3115 getFirstprivateOperandsMutable().append(op.getResult());
3116}
3117
3118void acc::KernelsOp::addReduction(MLIRContext *context,
3119 mlir::acc::ReductionOp op,
3120 mlir::acc::ReductionRecipeOp recipe) {
3121 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3122 getReductionOperandsMutable().append(op.getResult());
3123}
3124
3125void acc::KernelsOp::addNumWorkersOperand(
3126 MLIRContext *context, mlir::Value newValue,
3127 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3128 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3129 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3130 getNumWorkersMutable()));
3131}
3132
3133void acc::KernelsOp::addVectorLengthOperand(
3134 MLIRContext *context, mlir::Value newValue,
3135 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3136 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3137 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3138 getVectorLengthMutable()));
3139}
3140void acc::KernelsOp::addAsyncOnly(
3141 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3142 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3143 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3144}
3145
3146void acc::KernelsOp::addAsyncOperand(
3147 MLIRContext *context, mlir::Value newValue,
3148 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3149 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3150 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3151 getAsyncOperandsMutable()));
3152}
3153
3154void acc::KernelsOp::addNumGangsOperands(
3155 MLIRContext *context, mlir::ValueRange newValues,
3156 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3158 if (getNumGangsSegmentsAttr())
3159 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3160
3161 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3162 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3163 getNumGangsMutable(), segments));
3164
3165 setNumGangsSegments(segments);
3166}
3167
3168void acc::KernelsOp::addWaitOnly(
3169 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3170 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3171 effectiveDeviceTypes));
3172}
3173void acc::KernelsOp::addWaitOperands(
3174 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3175 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3176
3178 if (getWaitOperandsSegments())
3179 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3180
3181 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3182 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3183 getWaitOperandsMutable(), segments));
3184 setWaitOperandsSegments(segments);
3185
3187 if (getHasWaitDevnumAttr())
3188 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3189 hasDevnums.insert(
3190 hasDevnums.end(),
3191 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3192 mlir::BoolAttr::get(context, hasDevnum));
3193 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3194}
3195
3196//===----------------------------------------------------------------------===//
3197// HostDataOp
3198//===----------------------------------------------------------------------===//
3199
3200LogicalResult acc::HostDataOp::verify() {
3201 if (getDataClauseOperands().empty())
3202 return emitError("at least one operand must appear on the host_data "
3203 "operation");
3204
3206 for (mlir::Value operand : getDataClauseOperands()) {
3207 auto useDeviceOp =
3208 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3209 if (!useDeviceOp)
3210 return emitError("expect data entry operation as defining op");
3211
3212 // Check for duplicate use_device clauses
3213 if (!seenVars.insert(useDeviceOp.getVar()).second)
3214 return emitError("duplicate use_device variable");
3215 }
3216 return success();
3217}
3218
3219void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3220 MLIRContext *context) {
3221 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3222}
3223
3224//===----------------------------------------------------------------------===//
3225// KernelEnvironmentOp
3226//===----------------------------------------------------------------------===//
3227
3228void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
3229 RewritePatternSet &results, MLIRContext *context) {
3230 results.add<RemoveEmptyKernelEnvironment>(context);
3231}
3232
3233//===----------------------------------------------------------------------===//
3234// LoopOp
3235//===----------------------------------------------------------------------===//
3236
3237static ParseResult parseGangValue(
3238 OpAsmParser &parser, llvm::StringRef keyword,
3241 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
3242 bool &needCommaBetweenValues, bool &newValue) {
3243 if (succeeded(parser.parseOptionalKeyword(keyword))) {
3244 if (parser.parseEqual())
3245 return failure();
3246 if (parser.parseOperand(operands.emplace_back()) ||
3247 parser.parseColonType(types.emplace_back()))
3248 return failure();
3249 attributes.push_back(gangArgType);
3250 needCommaBetweenValues = true;
3251 newValue = true;
3252 }
3253 return success();
3254}
3255
3256static ParseResult parseGangClause(
3257 OpAsmParser &parser,
3259 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3260 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3261 mlir::ArrayAttr &gangOnlyDeviceType) {
3262 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3263 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3264 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3266 bool needCommaBetweenValues = false;
3267 bool needCommaBeforeOperands = false;
3268
3269 if (failed(parser.parseOptionalLParen())) {
3270 // Gang only keyword
3271 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3272 parser.getContext(), mlir::acc::DeviceType::None));
3273 gangOnlyDeviceType =
3274 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3275 return success();
3276 }
3277
3278 // Parse gang only attributes
3279 if (succeeded(parser.parseOptionalLSquare())) {
3280 // Parse gang only attributes
3281 if (failed(parser.parseCommaSeparatedList([&]() {
3282 if (parser.parseAttribute(
3283 gangOnlyDeviceTypeAttributes.emplace_back()))
3284 return failure();
3285 return success();
3286 })))
3287 return failure();
3288 if (parser.parseRSquare())
3289 return failure();
3290 needCommaBeforeOperands = true;
3291 }
3292
3293 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3294 mlir::acc::GangArgType::Num);
3295 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3296 mlir::acc::GangArgType::Dim);
3297 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3298 parser.getContext(), mlir::acc::GangArgType::Static);
3299
3300 do {
3301 if (needCommaBeforeOperands) {
3302 needCommaBeforeOperands = false;
3303 continue;
3304 }
3305
3306 if (failed(parser.parseLBrace()))
3307 return failure();
3308
3309 int32_t crtOperandsSize = gangOperands.size();
3310 while (true) {
3311 bool newValue = false;
3312 bool needValue = false;
3313 if (needCommaBetweenValues) {
3314 if (succeeded(parser.parseOptionalComma()))
3315 needValue = true; // expect a new value after comma.
3316 else
3317 break;
3318 }
3319
3320 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3321 gangOperands, gangOperandsType,
3322 gangArgTypeAttributes, argNum,
3323 needCommaBetweenValues, newValue)))
3324 return failure();
3325 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3326 gangOperands, gangOperandsType,
3327 gangArgTypeAttributes, argDim,
3328 needCommaBetweenValues, newValue)))
3329 return failure();
3330 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3331 gangOperands, gangOperandsType,
3332 gangArgTypeAttributes, argStatic,
3333 needCommaBetweenValues, newValue)))
3334 return failure();
3335
3336 if (!newValue && needValue) {
3337 parser.emitError(parser.getCurrentLocation(),
3338 "new value expected after comma");
3339 return failure();
3340 }
3341
3342 if (!newValue)
3343 break;
3344 }
3345
3346 if (gangOperands.empty())
3347 return parser.emitError(
3348 parser.getCurrentLocation(),
3349 "expect at least one of num, dim or static values");
3350
3351 if (failed(parser.parseRBrace()))
3352 return failure();
3353
3354 if (succeeded(parser.parseOptionalLSquare())) {
3355 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3356 parser.parseRSquare())
3357 return failure();
3358 } else {
3359 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3360 parser.getContext(), mlir::acc::DeviceType::None));
3361 }
3362
3363 seg.push_back(gangOperands.size() - crtOperandsSize);
3364
3365 } while (succeeded(parser.parseOptionalComma()));
3366
3367 if (failed(parser.parseRParen()))
3368 return failure();
3369
3370 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3371 gangArgTypeAttributes.end());
3372 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3373 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3374
3376 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3377 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3378
3379 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3380 return success();
3381}
3382
3384 mlir::OperandRange operands, mlir::TypeRange types,
3385 std::optional<mlir::ArrayAttr> gangArgTypes,
3386 std::optional<mlir::ArrayAttr> deviceTypes,
3387 std::optional<mlir::DenseI32ArrayAttr> segments,
3388 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3389
3390 if (operands.begin() == operands.end() &&
3391 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3392 return;
3393 }
3394
3395 p << "(";
3396
3397 printDeviceTypes(p, gangOnlyDeviceTypes);
3398
3399 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3400 hasDeviceTypeValues(deviceTypes))
3401 p << ", ";
3402
3403 if (hasDeviceTypeValues(deviceTypes)) {
3404 unsigned opIdx = 0;
3405 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3406 p << "{";
3407 llvm::interleaveComma(
3408 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3409 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3410 (*gangArgTypes)[opIdx]);
3411 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3412 p << LoopOp::getGangNumKeyword();
3413 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3414 p << LoopOp::getGangDimKeyword();
3415 else if (gangArgTypeAttr.getValue() ==
3416 mlir::acc::GangArgType::Static)
3417 p << LoopOp::getGangStaticKeyword();
3418 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3419 ++opIdx;
3420 });
3421 p << "}";
3422 printSingleDeviceType(p, it.value());
3423 });
3424 }
3425 p << ")";
3426}
3427
3429 std::optional<mlir::ArrayAttr> segments,
3430 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3431 if (!segments)
3432 return false;
3433 for (auto attr : *segments) {
3434 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3435 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3436 return true;
3437 }
3438 return false;
3439}
3440
3441/// Check for duplicates in the DeviceType array attribute.
3442/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3443static std::optional<mlir::acc::DeviceType>
3444checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3445 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3446 if (!deviceTypes)
3447 return std::nullopt;
3448 for (auto attr : deviceTypes) {
3449 auto deviceTypeAttr =
3450 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3451 if (!deviceTypeAttr)
3452 return mlir::acc::DeviceType::None;
3453 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3454 return deviceTypeAttr.getValue();
3455 }
3456 return std::nullopt;
3457}
3458
3459LogicalResult acc::LoopOp::verify() {
3460 if (getUpperbound().size() != getStep().size())
3461 return emitError() << "number of upperbounds expected to be the same as "
3462 "number of steps";
3463
3464 if (getUpperbound().size() != getLowerbound().size())
3465 return emitError() << "number of upperbounds expected to be the same as "
3466 "number of lowerbounds";
3467
3468 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3469 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3470 return emitError() << "inclusiveUpperbound size is expected to be the same"
3471 << " as upperbound size";
3472
3473 // Check collapse
3474 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3475 return emitOpError() << "collapse device_type attr must be define when"
3476 << " collapse attr is present";
3477
3478 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3479 getCollapseAttr().getValue().size() !=
3480 getCollapseDeviceTypeAttr().getValue().size())
3481 return emitOpError() << "collapse attribute count must match collapse"
3482 << " device_type count";
3483 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3484 return emitOpError() << "duplicate device_type `"
3485 << acc::stringifyDeviceType(*duplicateDeviceType)
3486 << "` found in collapseDeviceType attribute";
3487
3488 // Check gang
3489 if (!getGangOperands().empty()) {
3490 if (!getGangOperandsArgType())
3491 return emitOpError() << "gangOperandsArgType attribute must be defined"
3492 << " when gang operands are present";
3493
3494 if (getGangOperands().size() !=
3495 getGangOperandsArgTypeAttr().getValue().size())
3496 return emitOpError() << "gangOperandsArgType attribute count must match"
3497 << " gangOperands count";
3498 }
3499 if (getGangAttr()) {
3500 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3501 return emitOpError() << "duplicate device_type `"
3502 << acc::stringifyDeviceType(*duplicateDeviceType)
3503 << "` found in gang attribute";
3504 }
3505
3507 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3508 getGangOperandsDeviceTypeAttr(), "gang")))
3509 return failure();
3510
3511 // Check worker
3512 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3513 return emitOpError() << "duplicate device_type `"
3514 << acc::stringifyDeviceType(*duplicateDeviceType)
3515 << "` found in worker attribute";
3516 if (auto duplicateDeviceType =
3517 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3518 return emitOpError() << "duplicate device_type `"
3519 << acc::stringifyDeviceType(*duplicateDeviceType)
3520 << "` found in workerNumOperandsDeviceType attribute";
3521 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3522 getWorkerNumOperandsDeviceTypeAttr(),
3523 "worker")))
3524 return failure();
3525
3526 // Check vector
3527 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3528 return emitOpError() << "duplicate device_type `"
3529 << acc::stringifyDeviceType(*duplicateDeviceType)
3530 << "` found in vector attribute";
3531 if (auto duplicateDeviceType =
3532 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3533 return emitOpError() << "duplicate device_type `"
3534 << acc::stringifyDeviceType(*duplicateDeviceType)
3535 << "` found in vectorOperandsDeviceType attribute";
3536 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3537 getVectorOperandsDeviceTypeAttr(),
3538 "vector")))
3539 return failure();
3540
3542 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3543 getTileOperandsDeviceTypeAttr(), "tile")))
3544 return failure();
3545
3546 // auto, independent and seq attribute are mutually exclusive.
3547 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3548 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3549 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3550 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3551 return emitError() << "only one of auto, independent, seq can be present "
3552 "at the same time";
3553 }
3554
3555 // Check that at least one of auto, independent, or seq is present
3556 // for the device-independent default clauses.
3557 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3558 return attr.getValue() == mlir::acc::DeviceType::None;
3559 };
3560 bool hasDefaultSeq =
3561 getSeqAttr()
3562 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3563 hasDeviceNone)
3564 : false;
3565 bool hasDefaultIndependent =
3566 getIndependentAttr()
3567 ? llvm::any_of(
3568 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3569 hasDeviceNone)
3570 : false;
3571 bool hasDefaultAuto =
3572 getAuto_Attr()
3573 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3574 hasDeviceNone)
3575 : false;
3576 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3577 return emitError()
3578 << "at least one of auto, independent, seq must be present";
3579 }
3580
3581 // Gang, worker and vector are incompatible with seq.
3582 if (getSeqAttr()) {
3583 for (auto attr : getSeqAttr()) {
3584 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3585 if (hasVector(deviceTypeAttr.getValue()) ||
3586 getVectorValue(deviceTypeAttr.getValue()) ||
3587 hasWorker(deviceTypeAttr.getValue()) ||
3588 getWorkerValue(deviceTypeAttr.getValue()) ||
3589 hasGang(deviceTypeAttr.getValue()) ||
3590 getGangValue(mlir::acc::GangArgType::Num,
3591 deviceTypeAttr.getValue()) ||
3592 getGangValue(mlir::acc::GangArgType::Dim,
3593 deviceTypeAttr.getValue()) ||
3594 getGangValue(mlir::acc::GangArgType::Static,
3595 deviceTypeAttr.getValue()))
3596 return emitError() << "gang, worker or vector cannot appear with seq";
3597 }
3598 }
3599
3600 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3601 mlir::acc::PrivateRecipeOp>(
3602 *this, getPrivateOperands(), "private")))
3603 return failure();
3604
3605 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3606 mlir::acc::FirstprivateRecipeOp>(
3607 *this, getFirstprivateOperands(), "firstprivate")))
3608 return failure();
3609
3610 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3611 mlir::acc::ReductionRecipeOp>(
3612 *this, getReductionOperands(), "reduction")))
3613 return failure();
3614
3615 if (getCombined().has_value() &&
3616 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3617 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3618 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3619 return emitError("unexpected combined constructs attribute");
3620 }
3621
3622 // Check non-empty body().
3623 if (getRegion().empty())
3624 return emitError("expected non-empty body.");
3625
3626 if (getUnstructured()) {
3627 if (!isContainerLike())
3628 return emitError(
3629 "unstructured acc.loop must not have induction variables");
3630 } else if (isContainerLike()) {
3631 // When it is container-like - it is expected to hold a loop-like operation.
3632 // Obtain the maximum collapse count - we use this to check that there
3633 // are enough loops contained.
3634 uint64_t collapseCount = getCollapseValue().value_or(1);
3635 if (getCollapseAttr()) {
3636 for (auto collapseEntry : getCollapseAttr()) {
3637 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3638 if (intAttr.getValue().getZExtValue() > collapseCount)
3639 collapseCount = intAttr.getValue().getZExtValue();
3640 }
3641 }
3642
3643 // We want to check that we find enough loop-like operations inside.
3644 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3645 // level.
3646 mlir::Operation *expectedParent = this->getOperation();
3647 bool foundSibling = false;
3648 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3649 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3650 // This effectively checks that we are not looking at a sibling loop.
3651 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3652 expectedParent) {
3653 foundSibling = true;
3655 }
3656
3657 collapseCount--;
3658 expectedParent = op;
3659 }
3660 // We found enough contained loops.
3661 if (collapseCount == 0)
3664 });
3665
3666 if (foundSibling)
3667 return emitError("found sibling loops inside container-like acc.loop");
3668 if (collapseCount != 0)
3669 return emitError("failed to find enough loop-like operations inside "
3670 "container-like acc.loop");
3671 }
3672
3673 return success();
3674}
3675
3676unsigned LoopOp::getNumDataOperands() {
3677 return getReductionOperands().size() + getPrivateOperands().size() +
3678 getFirstprivateOperands().size();
3679}
3680
3681Value LoopOp::getDataOperand(unsigned i) {
3682 unsigned numOptional =
3683 getLowerbound().size() + getUpperbound().size() + getStep().size();
3684 numOptional += getGangOperands().size();
3685 numOptional += getVectorOperands().size();
3686 numOptional += getWorkerNumOperands().size();
3687 numOptional += getTileOperands().size();
3688 numOptional += getCacheOperands().size();
3689 return getOperand(numOptional + i);
3690}
3691
3692bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3693
3694bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3695 return hasDeviceType(getAuto_(), deviceType);
3696}
3697
3698bool LoopOp::hasIndependent() {
3699 return hasIndependent(mlir::acc::DeviceType::None);
3700}
3701
3702bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3703 return hasDeviceType(getIndependent(), deviceType);
3704}
3705
3706bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3707
3708bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3709 return hasDeviceType(getSeq(), deviceType);
3710}
3711
3712mlir::Value LoopOp::getVectorValue() {
3713 return getVectorValue(mlir::acc::DeviceType::None);
3714}
3715
3716mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3717 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3718 getVectorOperands(), deviceType);
3719}
3720
3721bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3722
3723bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3724 return hasDeviceType(getVector(), deviceType);
3725}
3726
3727mlir::Value LoopOp::getWorkerValue() {
3728 return getWorkerValue(mlir::acc::DeviceType::None);
3729}
3730
3731mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3732 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3733 getWorkerNumOperands(), deviceType);
3734}
3735
3736bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3737
3738bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3739 return hasDeviceType(getWorker(), deviceType);
3740}
3741
3742mlir::Operation::operand_range LoopOp::getTileValues() {
3743 return getTileValues(mlir::acc::DeviceType::None);
3744}
3745
LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Returns the tile operands of the segment associated with `deviceType`.
  // Lookup goes through the shared segment helper, using the tile operands'
  // device-type attribute and segment sizes.
  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
                               getTileOperandsSegments(), deviceType);
}
3751
3752std::optional<int64_t> LoopOp::getCollapseValue() {
3753 return getCollapseValue(mlir::acc::DeviceType::None);
3754}
3755
3756std::optional<int64_t>
3757LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3758 if (!getCollapseAttr())
3759 return std::nullopt;
3760 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3761 auto intAttr =
3762 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3763 return intAttr.getValue().getZExtValue();
3764 }
3765 return std::nullopt;
3766}
3767
3768mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3769 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3770}
3771
/// Returns the first gang operand of kind `gangArgType` within the segment
/// belonging to `deviceType`. Returns a null Value if there are no gang
/// operands, no segment for that device type, or no operand of that kind.
mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
                                 mlir::acc::DeviceType deviceType) {
  if (getGangOperands().empty())
    return {};
  if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
    // Offset of this device type's segment within the flat operand list.
    int32_t nbOperandsBefore = 0;
    for (unsigned i = 0; i < *pos; ++i)
      nbOperandsBefore += (*getGangOperandsSegments())[i];
        getGangOperands()
            .drop_front(nbOperandsBefore)
            .take_front((*getGangOperandsSegments())[*pos]);

    // The arg-type array is indexed in parallel with the flat operand list,
    // so it starts at the same offset as the segment.
    int32_t argTypeIdx = nbOperandsBefore;
    for (auto value : values) {
      // NOTE(review): the dyn_cast result is dereferenced unchecked; `cast`
      // would state the assumption that every entry is a GangArgTypeAttr.
      auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
          (*getGangOperandsArgType())[argTypeIdx]);
      if (gangArgTypeAttr.getValue() == gangArgType)
        return value;
      ++argTypeIdx;
    }
  }
  return {};
}
3796
3797bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3798
3799bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3800 return hasDeviceType(getGang(), deviceType);
3801}
3802
3803llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3804 return {&getRegion()};
3805}
3806
/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
/// `(` ssa-id-and-type-list `)`
/// region
///
/// The `control(...)` header is optional. When the `control` keyword is
/// present, the typed induction variables are parsed first. Then exactly one
/// lower bound, upper bound and step operand is parsed per induction
/// variable, each list followed by its types. The induction variables (none
/// if the header is absent) become the region's entry block arguments.
ParseResult
    SmallVectorImpl<Type> &lowerboundType,
    SmallVectorImpl<Type> &upperboundType,
    SmallVectorImpl<Type> &stepType) {

  // Absence of `control` is not an error; only the region is parsed then.
  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    // `(ivs) = (lbs : types) to (ubs : types) step (steps : types)`; the
    // operand lists are required to have inductionVars.size() entries.
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  return parser.parseRegion(region, inductionVars);
}
3842
    ValueRange lowerbound, TypeRange lowerboundType,
    ValueRange upperbound, TypeRange upperboundType,
    ValueRange steps, TypeRange stepType) {
  // Prints the `control(...) = (...) to (...) step (...)` header only when the
  // region's entry block has arguments (the induction variables). The region
  // is then printed without its entry block arguments, since the header
  // already shows them.
  ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << acc::LoopOp::getControlKeyword() << "(";
    llvm::interleaveComma(regionArgs, p,
                          [&p](Value v) { p << v << " : " << v.getType(); });
    // NOTE(review): `") " << " step ("` emits two spaces before `step`.
    // Presumably harmless for parsing, but likely unintended.
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") " << " step (" << steps
      << " : " << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
3858
3859void acc::LoopOp::addSeq(MLIRContext *context,
3860 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3861 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3862 effectiveDeviceTypes));
3863}
3864
3865void acc::LoopOp::addIndependent(
3866 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3867 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3868 context, getIndependentAttr(), effectiveDeviceTypes));
3869}
3870
3871void acc::LoopOp::addAuto(MLIRContext *context,
3872 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3873 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3874 effectiveDeviceTypes));
3875}
3876
/// Appends `value` as the collapse count for every effective device type, or
/// for DeviceType::None when the list is empty. All existing entries are kept.
/// The collapse values and their device types live in two parallel array
/// attributes, which are rebuilt and re-set together.
void acc::LoopOp::setCollapseForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    llvm::APInt value) {

  // The two parallel arrays must be either both present or both absent.
  assert((getCollapseAttr() == nullptr) ==
         (getCollapseDeviceTypeAttr() == nullptr));
  // Values are stored as i64 IntegerAttrs below, so require a 64-bit APInt.
  assert(value.getBitWidth() == 64);

  // Start from a copy of the existing (value, device-type) pairs.
  if (getCollapseAttr()) {
    for (const auto &existing :
         llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
      newValues.push_back(std::get<0>(existing));
      newDeviceTypes.push_back(std::get<1>(existing));
    }
  }

  if (effectiveDeviceTypes.empty()) {
    // If the effective device-types list is empty, this is before there are any
    // being applied by device_type, so this should be added as a 'none'.
    newValues.push_back(
        mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
    newDeviceTypes.push_back(
        acc::DeviceTypeAttr::get(context, DeviceType::None));
  } else {
    // One (value, device-type) pair per effective device type.
    for (DeviceType dt : effectiveDeviceTypes) {
      newValues.push_back(
          mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
      newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
    }
  }

  setCollapseAttr(ArrayAttr::get(context, newValues));
  setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
}
3913
/// Adds `values` as tile operands for the effective device types. It copies
/// the existing segment sizes and passes them to the shared helper along with
/// the mutable tile operand list. It then stores back the device-type
/// attribute the helper returns and the (presumably extended) segment sizes.
void acc::LoopOp::setTileForDeviceTypes(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
    ValueRange values) {
  if (getTileOperandsSegments())
    llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));

  setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getTileOperandsMutable(), segments));

  setTileOperandsSegments(segments);
}
3927
3928void acc::LoopOp::addVectorOperand(
3929 MLIRContext *context, mlir::Value newValue,
3930 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3931 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3932 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3933 newValue, getVectorOperandsMutable()));
3934}
3935
3936void acc::LoopOp::addEmptyVector(
3937 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3938 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3939 effectiveDeviceTypes));
3940}
3941
3942void acc::LoopOp::addWorkerNumOperand(
3943 MLIRContext *context, mlir::Value newValue,
3944 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3945 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3946 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3947 newValue, getWorkerNumOperandsMutable()));
3948}
3949
3950void acc::LoopOp::addEmptyWorker(
3951 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3952 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3953 effectiveDeviceTypes));
3954}
3955
3956void acc::LoopOp::addEmptyGang(
3957 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3958 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3959 effectiveDeviceTypes));
3960}
3961
3962bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3963 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3964 return attr.getValue() == dt;
3965 };
3966 auto testFromArr = [=](ArrayAttr arr) -> bool {
3967 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3968 };
3969
3970 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3971 return true;
3972 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3973 return true;
3974 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3975 return true;
3976
3977 return false;
3978}
3979
3980bool acc::LoopOp::hasDefaultGangWorkerVector() {
3981 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3982 hasGang() || getGangValue(GangArgType::Num) ||
3983 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3984}
3985
3986acc::LoopParMode
3987acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3988 if (hasSeq(deviceType))
3989 return LoopParMode::loop_seq;
3990 if (hasAuto(deviceType))
3991 return LoopParMode::loop_auto;
3992 if (hasIndependent(deviceType))
3993 return LoopParMode::loop_independent;
3994 if (hasSeq())
3995 return LoopParMode::loop_seq;
3996 if (hasAuto())
3997 return LoopParMode::loop_auto;
3998 assert(hasIndependent() &&
3999 "loop must have default auto, seq, or independent");
4000 return LoopParMode::loop_independent;
4001}
4002
/// Appends `values` as gang operands for the effective device types. For every
/// segment the helper adds, it also appends one GangArgTypeAttr per entry of
/// `argTypes`. This keeps the arg-type array aligned with the flat operand
/// list (presumably argTypes.size() == values.size() — TODO confirm).
void acc::LoopOp::addGangOperands(
    MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
  if (std::optional<ArrayRef<int32_t>> existingSegments =
          getGangOperandsSegments())
    llvm::copy(*existingSegments, std::back_inserter(segments));

  // Segment count before the helper runs; the difference afterwards is the
  // number of segments it added.
  unsigned beforeCount = segments.size();

  setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
      getGangOperandsMutable(), segments));

  setGangOperandsSegments(segments);

  // This is a bit of extra work to make sure we update the 'types' correctly by
  // adding to the types collection the correct number of times. We could
  // potentially add something similar to the
  // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
  // excessive for a one-off case.
  unsigned numAdded = segments.size() - beforeCount;

  if (numAdded > 0) {
    if (getGangOperandsArgTypeAttr())
      llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));

    // One full copy of `argTypes` per newly added segment.
    for (auto i : llvm::index_range(0u, numAdded)) {
      llvm::transform(argTypes, std::back_inserter(gangTypes),
                      [=](mlir::acc::GangArgType gangTy) {
                        return mlir::acc::GangArgTypeAttr::get(context, gangTy);
                      });
      (void)i;
    }

    setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
  }
}
4042
4043void acc::LoopOp::addPrivatization(MLIRContext *context,
4044 mlir::acc::PrivateOp op,
4045 mlir::acc::PrivateRecipeOp recipe) {
4046 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4047 getPrivateOperandsMutable().append(op.getResult());
4048}
4049
4050void acc::LoopOp::addFirstPrivatization(
4051 MLIRContext *context, mlir::acc::FirstprivateOp op,
4052 mlir::acc::FirstprivateRecipeOp recipe) {
4053 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4054 getFirstprivateOperandsMutable().append(op.getResult());
4055}
4056
4057void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
4058 mlir::acc::ReductionRecipeOp recipe) {
4059 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
4060 getReductionOperandsMutable().append(op.getResult());
4061}
4062
4063//===----------------------------------------------------------------------===//
4064// DataOp
4065//===----------------------------------------------------------------------===//
4066
/// Verifies acc.data. The construct needs at least one operand or a default
/// attribute. Every data clause operand must be produced by one of the listed
/// data entry/exit ops or by acc.getdeviceptr.
LogicalResult acc::DataOp::verify() {
  // 2.6.5. Data Construct restriction
  // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
  // attach, or default clause must appear on a data construct.
  if (getOperands().empty() && !getDefaultAttr())
    return emitError("at least one operand or the default attribute "
                     "must appear on the data operation");

  // Block arguments have no defining op, so they are rejected explicitly
  // before `isa` is applied to getDefiningOp().
  for (mlir::Value operand : getDataClauseOperands())
    if (isa<BlockArgument>(operand) ||
        !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
                   acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
                   acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
            operand.getDefiningOp()))
      return emitError("expect data entry/exit operation or acc.getdeviceptr "
                       "as defining op");

    return failure();

  return success();
}
4089
4090unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
4091
4092Value DataOp::getDataOperand(unsigned i) {
4093 unsigned numOptional = getIfCond() ? 1 : 0;
4094 numOptional += getAsyncOperands().size() ? 1 : 0;
4095 numOptional += getWaitOperands().size();
4096 return getOperand(numOptional + i);
4097}
4098
4099bool acc::DataOp::hasAsyncOnly() {
4100 return hasAsyncOnly(mlir::acc::DeviceType::None);
4101}
4102
4103bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4104 return hasDeviceType(getAsyncOnly(), deviceType);
4105}
4106
4107mlir::Value DataOp::getAsyncValue() {
4108 return getAsyncValue(mlir::acc::DeviceType::None);
4109}
4110
mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
  // Looks up the async operand of the segment tagged with `deviceType`.
                                     getAsyncOperands(), deviceType);
}
4115
4116bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4117
4118bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4119 return hasDeviceType(getWaitOnly(), deviceType);
4120}
4121
4122mlir::Operation::operand_range DataOp::getWaitValues() {
4123 return getWaitValues(mlir::acc::DeviceType::None);
4124}
4125
DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Delegates to a shared helper. It passes the wait operands, their
  // device-type tags, segment sizes and the per-segment devnum flags.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
4132
4133mlir::Value DataOp::getWaitDevnum() {
4134 return getWaitDevnum(mlir::acc::DeviceType::None);
4135}
4136
4137mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4138 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4139 getWaitOperandsSegments(), getHasWaitDevnum(),
4140 deviceType);
4141}
4142
4143void acc::DataOp::addAsyncOnly(
4144 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4145 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4146 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4147}
4148
4149void acc::DataOp::addAsyncOperand(
4150 MLIRContext *context, mlir::Value newValue,
4151 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4152 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4153 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4154 getAsyncOperandsMutable()));
4155}
4156
4157void acc::DataOp::addWaitOnly(MLIRContext *context,
4158 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4159 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4160 effectiveDeviceTypes));
4161}
4162
/// Adds `newValues` as wait operands for the effective device types. It also
/// appends `hasDevnum` to the per-segment devnum flags: once per effective
/// device type, or once when the list is empty.
void acc::DataOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // An empty device-type list still produces one (device-type-less) entry,
  // hence the max with 1.
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
4185
4186//===----------------------------------------------------------------------===//
4187// ExitDataOp
4188//===----------------------------------------------------------------------===//
4189
4190LogicalResult acc::ExitDataOp::verify() {
4191 // 2.6.6. Data Exit Directive restriction
4192 // At least one copyout, delete, or detach clause must appear on an exit data
4193 // directive.
4194 if (getDataClauseOperands().empty())
4195 return emitError("at least one operand must be present in dataOperands on "
4196 "the exit data operation");
4197
4198 // The async attribute represent the async clause without value. Therefore the
4199 // attribute and operand cannot appear at the same time.
4200 if (getAsyncOperand() && getAsync())
4201 return emitError("async attribute cannot appear with asyncOperand");
4202
4203 // The wait attribute represent the wait clause without values. Therefore the
4204 // attribute and operands cannot appear at the same time.
4205 if (!getWaitOperands().empty() && getWait())
4206 return emitError("wait attribute cannot appear with waitOperands");
4207
4208 if (getWaitDevnum() && getWaitOperands().empty())
4209 return emitError("wait_devnum cannot appear without waitOperands");
4210
4211 return success();
4212}
4213
4214unsigned ExitDataOp::getNumDataOperands() {
4215 return getDataClauseOperands().size();
4216}
4217
4218Value ExitDataOp::getDataOperand(unsigned i) {
4219 unsigned numOptional = getIfCond() ? 1 : 0;
4220 numOptional += getAsyncOperand() ? 1 : 0;
4221 numOptional += getWaitDevnum() ? 1 : 0;
4222 return getOperand(getWaitOperands().size() + numOptional + i);
4223}
4224
4225void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4226 MLIRContext *context) {
4227 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4228}
4229
4230void ExitDataOp::addAsyncOnly(MLIRContext *context,
4231 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4232 assert(effectiveDeviceTypes.empty());
4233 assert(!getAsyncAttr());
4234 assert(!getAsyncOperand());
4235
4236 setAsyncAttr(mlir::UnitAttr::get(context));
4237}
4238
4239void ExitDataOp::addAsyncOperand(
4240 MLIRContext *context, mlir::Value newValue,
4241 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getAsyncAttr());
4244 assert(!getAsyncOperand());
4245
4246 getAsyncOperandMutable().append(newValue);
4247}
4248
4249void ExitDataOp::addWaitOnly(MLIRContext *context,
4250 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4251 assert(effectiveDeviceTypes.empty());
4252 assert(!getWaitAttr());
4253 assert(getWaitOperands().empty());
4254 assert(!getWaitDevnum());
4255
4256 setWaitAttr(mlir::UnitAttr::get(context));
4257}
4258
4259void ExitDataOp::addWaitOperands(
4260 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4261 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4262 assert(effectiveDeviceTypes.empty());
4263 assert(!getWaitAttr());
4264 assert(getWaitOperands().empty());
4265 assert(!getWaitDevnum());
4266
4267 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4268 // operands list.
4269 if (hasDevnum) {
4270 getWaitDevnumMutable().append(newValues.front());
4271 newValues = newValues.drop_front();
4272 }
4273
4274 getWaitOperandsMutable().append(newValues);
4275}
4276
4277//===----------------------------------------------------------------------===//
4278// EnterDataOp
4279//===----------------------------------------------------------------------===//
4280
4281LogicalResult acc::EnterDataOp::verify() {
4282 // 2.6.6. Data Enter Directive restriction
4283 // At least one copyin, create, or attach clause must appear on an enter data
4284 // directive.
4285 if (getDataClauseOperands().empty())
4286 return emitError("at least one operand must be present in dataOperands on "
4287 "the enter data operation");
4288
4289 // The async attribute represent the async clause without value. Therefore the
4290 // attribute and operand cannot appear at the same time.
4291 if (getAsyncOperand() && getAsync())
4292 return emitError("async attribute cannot appear with asyncOperand");
4293
4294 // The wait attribute represent the wait clause without values. Therefore the
4295 // attribute and operands cannot appear at the same time.
4296 if (!getWaitOperands().empty() && getWait())
4297 return emitError("wait attribute cannot appear with waitOperands");
4298
4299 if (getWaitDevnum() && getWaitOperands().empty())
4300 return emitError("wait_devnum cannot appear without waitOperands");
4301
4302 for (mlir::Value operand : getDataClauseOperands())
4303 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4304 operand.getDefiningOp()))
4305 return emitError("expect data entry operation as defining op");
4306
4307 return success();
4308}
4309
4310unsigned EnterDataOp::getNumDataOperands() {
4311 return getDataClauseOperands().size();
4312}
4313
4314Value EnterDataOp::getDataOperand(unsigned i) {
4315 unsigned numOptional = getIfCond() ? 1 : 0;
4316 numOptional += getAsyncOperand() ? 1 : 0;
4317 numOptional += getWaitDevnum() ? 1 : 0;
4318 return getOperand(getWaitOperands().size() + numOptional + i);
4319}
4320
4321void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4322 MLIRContext *context) {
4323 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4324}
4325
4326void EnterDataOp::addAsyncOnly(
4327 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4328 assert(effectiveDeviceTypes.empty());
4329 assert(!getAsyncAttr());
4330 assert(!getAsyncOperand());
4331
4332 setAsyncAttr(mlir::UnitAttr::get(context));
4333}
4334
4335void EnterDataOp::addAsyncOperand(
4336 MLIRContext *context, mlir::Value newValue,
4337 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4338 assert(effectiveDeviceTypes.empty());
4339 assert(!getAsyncAttr());
4340 assert(!getAsyncOperand());
4341
4342 getAsyncOperandMutable().append(newValue);
4343}
4344
4345void EnterDataOp::addWaitOnly(MLIRContext *context,
4346 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4347 assert(effectiveDeviceTypes.empty());
4348 assert(!getWaitAttr());
4349 assert(getWaitOperands().empty());
4350 assert(!getWaitDevnum());
4351
4352 setWaitAttr(mlir::UnitAttr::get(context));
4353}
4354
4355void EnterDataOp::addWaitOperands(
4356 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4357 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4358 assert(effectiveDeviceTypes.empty());
4359 assert(!getWaitAttr());
4360 assert(getWaitOperands().empty());
4361 assert(!getWaitDevnum());
4362
4363 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4364 // operands list.
4365 if (hasDevnum) {
4366 getWaitDevnumMutable().append(newValues.front());
4367 newValues = newValues.drop_front();
4368 }
4369
4370 getWaitOperandsMutable().append(newValues);
4371}
4372
4373//===----------------------------------------------------------------------===//
4374// AtomicReadOp
4375//===----------------------------------------------------------------------===//
4376
4377LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4378
4379//===----------------------------------------------------------------------===//
4380// AtomicWriteOp
4381//===----------------------------------------------------------------------===//
4382
4383LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4384
4385//===----------------------------------------------------------------------===//
4386// AtomicUpdateOp
4387//===----------------------------------------------------------------------===//
4388
4389LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4390 PatternRewriter &rewriter) {
4391 if (op.isNoOp()) {
4392 rewriter.eraseOp(op);
4393 return success();
4394 }
4395
4396 if (Value writeVal = op.getWriteOpVal()) {
4397 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4398 op.getIfCond());
4399 return success();
4400 }
4401
4402 return failure();
4403}
4404
4405LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4406
4407LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4408
4409//===----------------------------------------------------------------------===//
4410// AtomicCaptureOp
4411//===----------------------------------------------------------------------===//
4412
4413AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4414 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4415 return op;
4416 return dyn_cast<AtomicReadOp>(getSecondOp());
4417}
4418
4419AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4420 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4421 return op;
4422 return dyn_cast<AtomicWriteOp>(getSecondOp());
4423}
4424
4425AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4426 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4427 return op;
4428 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4429}
4430
4431LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4432
4433//===----------------------------------------------------------------------===//
4434// DeclareEnterOp
4435//===----------------------------------------------------------------------===//
4436
/// Verifies the data clause operands of a declare-style operation. Unless
/// `requireAtLeastOneOperand` is false, the operand list must be non-empty.
/// Each operand must be produced by one of the listed data entry ops (or
/// acc.getdeviceptr); block arguments are rejected explicitly.
template <typename Op>
static LogicalResult
    bool requireAtLeastOneOperand = true) {
  if (operands.empty() && requireAtLeastOneOperand)
    return emitError(
        op->getLoc(),
        "at least one operand must appear on the declare operation");

  for (mlir::Value operand : operands) {
    if (isa<BlockArgument>(operand) ||
        !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
                   acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
                   acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
            operand.getDefiningOp()))
      return op.emitError(
          "expect valid declare data entry operation or acc.getdeviceptr "
          "as defining op");

    // Every accepted defining op is expected to expose a var and a data
    // clause. These are only asserted; the `(void)` casts silence unused
    // warnings when assertions are compiled out.
    mlir::Value var{getVar(operand.getDefiningOp())};
    assert(var && "declare operands can only be data entry operations which "
                  "must have var");
    (void)var;
    std::optional<mlir::acc::DataClause> dataClauseOptional{
        getDataClause(operand.getDefiningOp())};
    assert(dataClauseOptional.has_value() &&
           "declare operands can only be data entry operations which must have "
           "dataClause");
    (void)dataClauseOptional;
  }

  return success();
}
4470
4471LogicalResult acc::DeclareEnterOp::verify() {
4472 return checkDeclareOperands(*this, this->getDataClauseOperands());
4473}
4474
4475//===----------------------------------------------------------------------===//
4476// DeclareExitOp
4477//===----------------------------------------------------------------------===//
4478
4479LogicalResult acc::DeclareExitOp::verify() {
4480 if (getToken())
4481 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4482 /*requireAtLeastOneOperand=*/false);
4483 return checkDeclareOperands(*this, this->getDataClauseOperands());
4484}
4485
4486//===----------------------------------------------------------------------===//
4487// DeclareOp
4488//===----------------------------------------------------------------------===//
4489
4490LogicalResult acc::DeclareOp::verify() {
4491 return checkDeclareOperands(*this, this->getDataClauseOperands());
4492}
4493
4494//===----------------------------------------------------------------------===//
4495// RoutineOp
4496//===----------------------------------------------------------------------===//
4497
4498static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4499 acc::DeviceType dtype) {
4500 unsigned parallelism = 0;
4501 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4502 parallelism += op.hasWorker(dtype) ? 1 : 0;
4503 parallelism += op.hasVector(dtype) ? 1 : 0;
4504 parallelism += op.hasSeq(dtype) ? 1 : 0;
4505 return parallelism;
4506}
4507
4508LogicalResult acc::RoutineOp::verify() {
4509 unsigned baseParallelism =
4510 getParallelismForDeviceType(*this, acc::DeviceType::None);
4511
4512 if (baseParallelism > 1)
4513 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4514 "be present at the same time";
4515
4516 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4517 ++dtypeInt) {
4518 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4519 if (dtype == acc::DeviceType::None)
4520 continue;
4521 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4522
4523 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4524 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4525 "be present at the same time for device_type `"
4526 << acc::stringifyDeviceType(dtype) << "`";
4527 }
4528
4529 return success();
4530}
4531
4532static ParseResult parseBindName(OpAsmParser &parser,
4533 mlir::ArrayAttr &bindIdName,
4534 mlir::ArrayAttr &bindStrName,
4535 mlir::ArrayAttr &deviceIdTypes,
4536 mlir::ArrayAttr &deviceStrTypes) {
4537 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4538 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4539 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4540 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4541
4542 if (failed(parser.parseCommaSeparatedList([&]() {
4543 mlir::Attribute newAttr;
4544 bool isSymbolRefAttr;
4545 auto parseResult = parser.parseAttribute(newAttr);
4546 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4547 bindIdNameAttrs.push_back(symbolRefAttr);
4548 isSymbolRefAttr = true;
4549 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4550 bindStrNameAttrs.push_back(stringAttr);
4551 isSymbolRefAttr = false;
4552 }
4553 if (parseResult)
4554 return failure();
4555 if (failed(parser.parseOptionalLSquare())) {
4556 if (isSymbolRefAttr) {
4557 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4558 parser.getContext(), mlir::acc::DeviceType::None));
4559 } else {
4560 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4561 parser.getContext(), mlir::acc::DeviceType::None));
4562 }
4563 } else {
4564 if (isSymbolRefAttr) {
4565 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4566 parser.parseRSquare())
4567 return failure();
4568 } else {
4569 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4570 parser.parseRSquare())
4571 return failure();
4572 }
4573 }
4574 return success();
4575 })))
4576 return failure();
4577
4578 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4579 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4580 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4581 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4582
4583 return success();
4584}
4585
4587 std::optional<mlir::ArrayAttr> bindIdName,
4588 std::optional<mlir::ArrayAttr> bindStrName,
4589 std::optional<mlir::ArrayAttr> deviceIdTypes,
4590 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4591 // Create combined vectors for all bind names and device types
4594
4595 // Append bindIdName and deviceIdTypes
4596 if (hasDeviceTypeValues(deviceIdTypes)) {
4597 allBindNames.append(bindIdName->begin(), bindIdName->end());
4598 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4599 }
4600
4601 // Append bindStrName and deviceStrTypes
4602 if (hasDeviceTypeValues(deviceStrTypes)) {
4603 allBindNames.append(bindStrName->begin(), bindStrName->end());
4604 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4605 }
4606
4607 // Print the combined sequence
4608 if (!allBindNames.empty())
4609 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4610 [&](const auto &pair) {
4611 p << std::get<0>(pair);
4612 printSingleDeviceType(p, std::get<1>(pair));
4613 });
4614}
4615
4616static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4617 mlir::ArrayAttr &gang,
4618 mlir::ArrayAttr &gangDim,
4619 mlir::ArrayAttr &gangDimDeviceTypes) {
4620
4621 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4622 gangDimDeviceTypeAttrs;
4623 bool needCommaBeforeOperands = false;
4624
4625 // Gang keyword only
4626 if (failed(parser.parseOptionalLParen())) {
4627 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4628 parser.getContext(), mlir::acc::DeviceType::None));
4629 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4630 return success();
4631 }
4632
4633 // Parse keyword only attributes
4634 if (succeeded(parser.parseOptionalLSquare())) {
4635 if (failed(parser.parseCommaSeparatedList([&]() {
4636 if (parser.parseAttribute(gangAttrs.emplace_back()))
4637 return failure();
4638 return success();
4639 })))
4640 return failure();
4641 if (parser.parseRSquare())
4642 return failure();
4643 needCommaBeforeOperands = true;
4644 }
4645
4646 if (needCommaBeforeOperands && failed(parser.parseComma()))
4647 return failure();
4648
4649 if (failed(parser.parseCommaSeparatedList([&]() {
4650 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4651 parser.parseColon() ||
4652 parser.parseAttribute(gangDimAttrs.emplace_back()))
4653 return failure();
4654 if (succeeded(parser.parseOptionalLSquare())) {
4655 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4656 parser.parseRSquare())
4657 return failure();
4658 } else {
4659 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4660 parser.getContext(), mlir::acc::DeviceType::None));
4661 }
4662 return success();
4663 })))
4664 return failure();
4665
4666 if (failed(parser.parseRParen()))
4667 return failure();
4668
4669 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4670 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4671 gangDimDeviceTypes =
4672 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4673
4674 return success();
4675}
4676
4678 std::optional<mlir::ArrayAttr> gang,
4679 std::optional<mlir::ArrayAttr> gangDim,
4680 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4681
4682 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4683 gang->size() == 1) {
4684 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4685 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4686 return;
4687 }
4688
4689 p << "(";
4690
4691 printDeviceTypes(p, gang);
4692
4693 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4694 p << ", ";
4695
4696 if (hasDeviceTypeValues(gangDimDeviceTypes))
4697 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4698 [&](const auto &pair) {
4699 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4700 p << std::get<0>(pair);
4701 printSingleDeviceType(p, std::get<1>(pair));
4702 });
4703
4704 p << ")";
4705}
4706
4707static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4708 mlir::ArrayAttr &deviceTypes) {
4710 // Keyword only
4711 if (failed(parser.parseOptionalLParen())) {
4712 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4713 parser.getContext(), mlir::acc::DeviceType::None));
4714 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4715 return success();
4716 }
4717
4718 // Parse device type attributes
4719 if (succeeded(parser.parseOptionalLSquare())) {
4720 if (failed(parser.parseCommaSeparatedList([&]() {
4721 if (parser.parseAttribute(attributes.emplace_back()))
4722 return failure();
4723 return success();
4724 })))
4725 return failure();
4726 if (parser.parseRSquare() || parser.parseRParen())
4727 return failure();
4728 }
4729 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4730 return success();
4731}
4732
4733static void
4735 std::optional<mlir::ArrayAttr> deviceTypes) {
4736
4737 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4738 auto deviceTypeAttr =
4739 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4740 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4741 return;
4742 }
4743
4744 if (!hasDeviceTypeValues(deviceTypes))
4745 return;
4746
4747 p << "([";
4748 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4749 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4750 p << dTypeAttr;
4751 });
4752 p << "])";
4753}
4754
4755bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4756
4757bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4758 return hasDeviceType(getWorker(), deviceType);
4759}
4760
4761bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4762
4763bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4764 return hasDeviceType(getVector(), deviceType);
4765}
4766
4767bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4768
4769bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4770 return hasDeviceType(getSeq(), deviceType);
4771}
4772
4773std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4774RoutineOp::getBindNameValue() {
4775 return getBindNameValue(mlir::acc::DeviceType::None);
4776}
4777
4778std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4779RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4780 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4781 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4782 return std::nullopt;
4783 }
4784
4785 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4786 auto attr = (*getBindIdName())[*pos];
4787 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4788 assert(symbolRefAttr && "expected SymbolRef");
4789 return symbolRefAttr;
4790 }
4791
4792 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4793 auto attr = (*getBindStrName())[*pos];
4794 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4795 assert(stringAttr && "expected String");
4796 return stringAttr;
4797 }
4798
4799 return std::nullopt;
4800}
4801
4802bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4803
4804bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4805 return hasDeviceType(getGang(), deviceType);
4806}
4807
4808std::optional<int64_t> RoutineOp::getGangDimValue() {
4809 return getGangDimValue(mlir::acc::DeviceType::None);
4810}
4811
4812std::optional<int64_t>
4813RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4814 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4815 return std::nullopt;
4816 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4817 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4818 return intAttr.getInt();
4819 }
4820 return std::nullopt;
4821}
4822
4823void RoutineOp::addSeq(MLIRContext *context,
4824 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4825 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4826 effectiveDeviceTypes));
4827}
4828
4829void RoutineOp::addVector(MLIRContext *context,
4830 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4831 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4832 effectiveDeviceTypes));
4833}
4834
4835void RoutineOp::addWorker(MLIRContext *context,
4836 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4837 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4838 effectiveDeviceTypes));
4839}
4840
4841void RoutineOp::addGang(MLIRContext *context,
4842 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4843 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4844 effectiveDeviceTypes));
4845}
4846
4847void RoutineOp::addGang(MLIRContext *context,
4848 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4849 uint64_t val) {
4852
4853 if (getGangDimAttr())
4854 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4855 if (getGangDimDeviceTypeAttr())
4856 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4857
4858 assert(dimValues.size() == deviceTypes.size());
4859
4860 if (effectiveDeviceTypes.empty()) {
4861 dimValues.push_back(
4862 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4863 deviceTypes.push_back(
4864 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4865 } else {
4866 for (DeviceType dt : effectiveDeviceTypes) {
4867 dimValues.push_back(
4868 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4869 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4870 }
4871 }
4872 assert(dimValues.size() == deviceTypes.size());
4873
4874 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4875 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4876}
4877
4878void RoutineOp::addBindStrName(MLIRContext *context,
4879 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4880 mlir::StringAttr val) {
4881 unsigned before = getBindStrNameDeviceTypeAttr()
4882 ? getBindStrNameDeviceTypeAttr().size()
4883 : 0;
4884
4885 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4886 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4887 unsigned after = getBindStrNameDeviceTypeAttr().size();
4888
4890 if (getBindStrNameAttr())
4891 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4892 for (unsigned i = 0; i < after - before; ++i)
4893 vals.push_back(val);
4894
4895 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4896}
4897
4898void RoutineOp::addBindIDName(MLIRContext *context,
4899 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4900 mlir::SymbolRefAttr val) {
4901 unsigned before =
4902 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4903
4904 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4905 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4906 unsigned after = getBindIdNameDeviceTypeAttr().size();
4907
4909 if (getBindIdNameAttr())
4910 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4911 for (unsigned i = 0; i < after - before; ++i)
4912 vals.push_back(val);
4913
4914 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4915}
4916
4917//===----------------------------------------------------------------------===//
4918// InitOp
4919//===----------------------------------------------------------------------===//
4920
4921LogicalResult acc::InitOp::verify() {
4922 Operation *currOp = *this;
4923 while ((currOp = currOp->getParentOp()))
4924 if (isComputeOperation(currOp))
4925 return emitOpError("cannot be nested in a compute operation");
4926 return success();
4927}
4928
4929void acc::InitOp::addDeviceType(MLIRContext *context,
4930 mlir::acc::DeviceType deviceType) {
4932 if (getDeviceTypesAttr())
4933 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4934
4935 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4936 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4937}
4938
4939//===----------------------------------------------------------------------===//
4940// ShutdownOp
4941//===----------------------------------------------------------------------===//
4942
4943LogicalResult acc::ShutdownOp::verify() {
4944 Operation *currOp = *this;
4945 while ((currOp = currOp->getParentOp()))
4946 if (isComputeOperation(currOp))
4947 return emitOpError("cannot be nested in a compute operation");
4948 return success();
4949}
4950
4951void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4952 mlir::acc::DeviceType deviceType) {
4954 if (getDeviceTypesAttr())
4955 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4956
4957 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4958 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4959}
4960
4961//===----------------------------------------------------------------------===//
4962// SetOp
4963//===----------------------------------------------------------------------===//
4964
4965LogicalResult acc::SetOp::verify() {
4966 Operation *currOp = *this;
4967 while ((currOp = currOp->getParentOp()))
4968 if (isComputeOperation(currOp))
4969 return emitOpError("cannot be nested in a compute operation");
4970 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4971 return emitOpError("at least one default_async, device_num, or device_type "
4972 "operand must appear");
4973 return success();
4974}
4975
4976//===----------------------------------------------------------------------===//
4977// UpdateOp
4978//===----------------------------------------------------------------------===//
4979
4980LogicalResult acc::UpdateOp::verify() {
4981 // At least one of host or device should have a value.
4982 if (getDataClauseOperands().empty())
4983 return emitError("at least one value must be present in dataOperands");
4984
4986 getAsyncOperandsDeviceTypeAttr(),
4987 "async")))
4988 return failure();
4989
4991 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4992 getWaitOperandsDeviceTypeAttr(), "wait")))
4993 return failure();
4994
4996 return failure();
4997
4998 for (mlir::Value operand : getDataClauseOperands())
4999 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
5000 operand.getDefiningOp()))
5001 return emitError("expect data entry/exit operation or acc.getdeviceptr "
5002 "as defining op");
5003
5004 return success();
5005}
5006
5007unsigned UpdateOp::getNumDataOperands() {
5008 return getDataClauseOperands().size();
5009}
5010
5011Value UpdateOp::getDataOperand(unsigned i) {
5012 unsigned numOptional = getAsyncOperands().size();
5013 numOptional += getIfCond() ? 1 : 0;
5014 return getOperand(getWaitOperands().size() + numOptional + i);
5015}
5016
5017void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
5018 MLIRContext *context) {
5019 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
5020}
5021
5022bool UpdateOp::hasAsyncOnly() {
5023 return hasAsyncOnly(mlir::acc::DeviceType::None);
5024}
5025
5026bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
5027 return hasDeviceType(getAsyncOnly(), deviceType);
5028}
5029
5030mlir::Value UpdateOp::getAsyncValue() {
5031 return getAsyncValue(mlir::acc::DeviceType::None);
5032}
5033
5034mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
5036 return {};
5037
5038 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
5039 return getAsyncOperands()[*pos];
5040
5041 return {};
5042}
5043
5044bool UpdateOp::hasWaitOnly() {
5045 return hasWaitOnly(mlir::acc::DeviceType::None);
5046}
5047
5048bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
5049 return hasDeviceType(getWaitOnly(), deviceType);
5050}
5051
5052mlir::Operation::operand_range UpdateOp::getWaitValues() {
5053 return getWaitValues(mlir::acc::DeviceType::None);
5054}
5055
5057UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
5059 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
5060 getHasWaitDevnum(), deviceType);
5061}
5062
5063mlir::Value UpdateOp::getWaitDevnum() {
5064 return getWaitDevnum(mlir::acc::DeviceType::None);
5065}
5066
5067mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
5068 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
5069 getWaitOperandsSegments(), getHasWaitDevnum(),
5070 deviceType);
5071}
5072
5073void UpdateOp::addAsyncOnly(MLIRContext *context,
5074 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5075 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
5076 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
5077}
5078
5079void UpdateOp::addAsyncOperand(
5080 MLIRContext *context, mlir::Value newValue,
5081 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5082 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5083 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
5084 getAsyncOperandsMutable()));
5085}
5086
5087void UpdateOp::addWaitOnly(MLIRContext *context,
5088 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5089 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
5090 effectiveDeviceTypes));
5091}
5092
5093void UpdateOp::addWaitOperands(
5094 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
5095 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
5096
5098 if (getWaitOperandsSegments())
5099 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
5100
5101 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5102 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5103 getWaitOperandsMutable(), segments));
5104 setWaitOperandsSegments(segments);
5105
5107 if (getHasWaitDevnumAttr())
5108 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5109 hasDevnums.insert(
5110 hasDevnums.end(),
5111 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
5112 mlir::BoolAttr::get(context, hasDevnum));
5113 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5114}
5115
5116//===----------------------------------------------------------------------===//
5117// WaitOp
5118//===----------------------------------------------------------------------===//
5119
5120LogicalResult acc::WaitOp::verify() {
5121 // The async attribute represent the async clause without value. Therefore the
5122 // attribute and operand cannot appear at the same time.
5123 if (getAsyncOperand() && getAsync())
5124 return emitError("async attribute cannot appear with asyncOperand");
5125
5126 if (getWaitDevnum() && getWaitOperands().empty())
5127 return emitError("wait_devnum cannot appear without waitOperands");
5128
5129 return success();
5130}
5131
5132#define GET_OP_CLASSES
5133#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5134
5135#define GET_ATTRDEF_CLASSES
5136#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5137
5138#define GET_TYPEDEF_CLASSES
5139#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5140
5141//===----------------------------------------------------------------------===//
5142// acc dialect utilities
5143//===----------------------------------------------------------------------===//
5144
5147 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
5149 accDataClauseOp)
5150 .Case<ACC_DATA_ENTRY_OPS>(
5151 [&](auto entry) { return entry.getVarPtr(); })
5152 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5153 [&](auto exit) { return exit.getVarPtr(); })
5154 .Default([&](mlir::Operation *) {
5156 })};
5157 return varPtr;
5158}
5159
5161 auto varPtr{
5163 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
5164 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5165 return varPtr;
5166}
5167
5169 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
5170 .Case<ACC_DATA_ENTRY_OPS>(
5171 [&](auto entry) { return entry.getVarType(); })
5172 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5173 [&](auto exit) { return exit.getVarType(); })
5174 .Default([&](mlir::Operation *) { return mlir::Type(); })};
5175 return varType;
5176}
5177
5180 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
5182 accDataClauseOp)
5183 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5184 [&](auto dataClause) { return dataClause.getAccPtr(); })
5185 .Default([&](mlir::Operation *) {
5187 })};
5188 return accPtr;
5189}
5190
5192 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
5194 [&](auto dataClause) { return dataClause.getAccVar(); })
5195 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5196 return accPtr;
5197}
5198
5200 auto varPtrPtr{
5202 .Case<ACC_DATA_ENTRY_OPS>(
5203 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
5204 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5205 return varPtrPtr;
5206}
5207
5212 accDataClauseOp)
5213 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5215 dataClause.getBounds().begin(), dataClause.getBounds().end());
5216 })
5217 .Default([&](mlir::Operation *) {
5219 })};
5220 return bounds;
5221}
5222
5226 accDataClauseOp)
5227 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5229 dataClause.getAsyncOperands().begin(),
5230 dataClause.getAsyncOperands().end());
5231 })
5232 .Default([&](mlir::Operation *) {
5234 });
5235}
5236
5237mlir::ArrayAttr
5240 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5241 return dataClause.getAsyncOperandsDeviceTypeAttr();
5242 })
5243 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5244}
5245
5246mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
5249 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
5250 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5251}
5252
5253std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
5254 auto name{
5256 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
5257 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
5258 return {};
5259 })};
5260 return name;
5261}
5262
5263std::optional<mlir::acc::DataClause>
5265 auto dataClause{
5267 accDataEntryOp)
5268 .Case<ACC_DATA_ENTRY_OPS>(
5269 [&](auto entry) { return entry.getDataClause(); })
5270 .Default([&](mlir::Operation *) { return std::nullopt; })};
5271 return dataClause;
5272}
5273
5275 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5276 .Case<ACC_DATA_ENTRY_OPS>(
5277 [&](auto entry) { return entry.getImplicit(); })
5278 .Default([&](mlir::Operation *) { return false; })};
5279 return implicit;
5280}
5281
5283 auto dataOperands{
5286 [&](auto entry) { return entry.getDataClauseOperands(); })
5287 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5288 return dataOperands;
5289}
5290
5293 auto dataOperands{
5296 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5297 .Default([&](mlir::Operation *) { return nullptr; })};
5298 return dataOperands;
5299}
5300
5301mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5302 auto recipe{
5304 .Case<ACC_DATA_ENTRY_OPS>(
5305 [&](auto entry) { return entry.getRecipeAttr(); })
5306 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5307 return recipe;
5308}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4677
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1492
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3428
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:2038
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4532
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:829
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1506
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:609
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2534
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:822
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:762
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:593
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:731
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2545
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2450
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:535
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4734
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3237
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2785
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3444
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4439
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:669
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2739
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:553
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:651
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:685
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1247
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-list ) step ( ssa-id-and-type-list )
Definition OpenACC.cpp:3812
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1993
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2581
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2120
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1257
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:677
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:740
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:564
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:577
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2320
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:716
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3843
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4707
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4616
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2433
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2608
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2724
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2387
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1235
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2700
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:803
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3256
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1773
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2769
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2364
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:695
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:2007
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2681
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:539
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3383
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2619
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:774
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:629
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:2048
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4498
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2370
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2805
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4586
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:156
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:69
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:46
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:54
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5191
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5160
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5179
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5264
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5292
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5209
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5282
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5253
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether a data operation is implicit.
Definition OpenACC.cpp:5274
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5301
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5224
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:167
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5199
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:204
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5246
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5168
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5146
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5238
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.