MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248
249 bool isDeviceData(Type pointer, Value var) const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
253 }
254};
255
256struct LLVMPointerPointerLikeModel
257 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
259 Type getElementType(Type pointer) const { return Type(); }
260
261 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType) const {
264 // For LLVM pointers, we need the valueType to determine what to load
265 if (!valueType)
266 return {};
267
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
269 }
270
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
272 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
274 return true;
275 }
276};
277
278struct MemrefAddressOfGlobalModel
279 : public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op) const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
284 }
285};
286
287struct MemrefGlobalVariableModel
288 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
289 memref::GlobalOp> {
290 bool isConstant(Operation *op) const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
293 }
294
295 Region *getInitRegion(Operation *op) const {
296 // GlobalOp uses attributes for initialization, not regions
297 return nullptr;
298 }
299
300 bool isDeviceData(Operation *op) const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
304 }
305};
306
307struct GPULaunchOffloadRegionModel
308 : public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
311 return cast<gpu::LaunchOp>(op).getBody();
312 }
313};
314
315/// Helper function for any of the times we need to modify an ArrayAttr based on
316/// a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
318/// list is empty).
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
321 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
325
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
329
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
332
333 return mlir::ArrayAttr::get(context, deviceTypes);
334}
335
336/// Helper function for any of the times we need to add operands that are
337/// affected by a device type list. Returns a new ArrayAttr with all of the
338/// existingDeviceTypes, plus the effective new ones (or an added none, if the
339/// new list is empty). Additionally, adds the arguments to the argCollection
340/// the correct number of times. This will also update a 'segments' array, even
341/// if it won't be used.
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
344 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
345 mlir::MutableOperandRange argCollection,
346 llvm::SmallVector<int32_t> &segments) {
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
350
351 if (newDeviceTypes.empty()) {
352 argCollection.append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
356 }
357
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
362 }
363
364 return mlir::ArrayAttr::get(context, deviceTypes);
365}
366
367/// Overload for when the 'segments' aren't needed.
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
370 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
371 mlir::MutableOperandRange argCollection) {
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
376}
377} // namespace
378
379//===----------------------------------------------------------------------===//
380// OpenACC operations
381//===----------------------------------------------------------------------===//
382
/// Dialect initialization: registers the generated operations, attributes and
/// types, then attaches external interface models to types and ops owned by
/// other dialects (builtin memref types, LLVM pointers, memref and gpu ops).
383void OpenACCDialect::initialize() {
384 addOperations<
385#define GET_OP_LIST
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
387 >();
388 addAttributes<
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
391 >();
392 addTypes<
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
395 >();
396
397 // By attaching interfaces here, we make the OpenACC dialect dependent on
398 // the other dialects. This is probably better than having dialects like LLVM
399 // and memref be dependent on OpenACC.
// Type interfaces: ranked/unranked memrefs and LLVM pointers become
// acc::PointerLikeType.
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
401 *getContext());
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
405 *getContext());
406
407 // Attach operation interfaces
// memref.get_global -> AddressOfGlobalOpInterface, memref.global ->
// GlobalVariableOpInterface, gpu.launch -> acc::OffloadRegionOpInterface.
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
409 *getContext());
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
412}
413
414//===----------------------------------------------------------------------===//
415// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
416// acc.kernel_environment / acc.data / acc.host_data / acc.loop
417//===----------------------------------------------------------------------===//
418
419/// Generic helper for single-region OpenACC ops that execute their body once
420/// and then return to the parent operation with their results (if any).
421static void
423 RegionBranchPoint point,
// From the parent op, the only successor is the op's region; from anywhere
// inside the region, the only successor is the parent op.
425 if (point.isParent()) {
426 regions.push_back(RegionSuccessor(&region));
427 return;
428 }
429
430 regions.push_back(RegionSuccessor::parent());
431}
432
434 RegionSuccessor successor) {
// Values only flow into the parent op's results; entering the region forwards
// no values.
435 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
436}
437
/// acc.kernels: its single region is entered from the parent op, and control
/// returns to the parent afterwards (see getSingleRegionOpSuccessorRegions).
438void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
440 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
441 regions);
442}
443
/// Only the transfer back to the parent carries values: the op's results.
444ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
445 return getSingleRegionSuccessorInputs(getOperation(), successor);
446}
447
/// acc.parallel: its single region is entered from the parent op, and control
/// returns to the parent afterwards (see getSingleRegionOpSuccessorRegions).
448void ParallelOp::getSuccessorRegions(
450 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
451 regions);
452}
453
/// Only the transfer back to the parent carries values: the op's results.
454ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
455 return getSingleRegionSuccessorInputs(getOperation(), successor);
456}
457
/// acc.serial: its single region is entered from the parent op, and control
/// returns to the parent afterwards (see getSingleRegionOpSuccessorRegions).
458void SerialOp::getSuccessorRegions(RegionBranchPoint point,
460 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
461 regions);
462}
463
/// Only the transfer back to the parent carries values: the op's results.
464ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
465 return getSingleRegionSuccessorInputs(getOperation(), successor);
466}
467
/// acc.kernel_environment: its single region is entered from the parent op,
/// and control returns to the parent afterwards
/// (see getSingleRegionOpSuccessorRegions).
468void KernelEnvironmentOp::getSuccessorRegions(
470 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
471 regions);
472}
473
/// Only the transfer back to the parent carries values: the op's results.
474ValueRange KernelEnvironmentOp::getSuccessorInputs(RegionSuccessor successor) {
475 return getSingleRegionSuccessorInputs(getOperation(), successor);
476}
477
/// acc.data: its single region is entered from the parent op, and control
/// returns to the parent afterwards (see getSingleRegionOpSuccessorRegions).
478void DataOp::getSuccessorRegions(RegionBranchPoint point,
480 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
481 regions);
482}
483
/// Only the transfer back to the parent carries values: the op's results.
484ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
485 return getSingleRegionSuccessorInputs(getOperation(), successor);
486}
487
/// acc.host_data: its single region is entered from the parent op, and
/// control returns to the parent afterwards
/// (see getSingleRegionOpSuccessorRegions).
488void HostDataOp::getSuccessorRegions(
490 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
491 regions);
492}
493
/// Only the transfer back to the parent carries values: the op's results.
494ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
495 return getSingleRegionSuccessorInputs(getOperation(), successor);
496}
497
/// Region-branch successors for acc.loop. The graph shape depends on whether
/// the loop is marked unstructured.
498void LoopOp::getSuccessorRegions(RegionBranchPoint point,
500 // Unstructured loops: the body may contain arbitrary CFG and early exits.
501 // At the RegionBranch level, only model entry into the body and exit to the
502 // parent; any backedges are represented inside the region CFG.
503 if (getUnstructured()) {
504 if (point.isParent()) {
505 regions.push_back(RegionSuccessor(&getRegion()));
506 return;
507 }
508 regions.push_back(RegionSuccessor::parent());
509 return;
510 }
511
512 // Structured loops: model a loop-shaped region graph similar to scf.for.
// From the parent and from the body alike, the next step may be (another)
// body execution or exit to the parent, so both successors are reported
// regardless of `point`.
513 regions.push_back(RegionSuccessor(&getRegion()));
514 regions.push_back(RegionSuccessor::parent());
515}
516
/// Exiting to the parent forwards the op's results; entering the body forwards
/// no values.
517ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
518 return getSingleRegionSuccessorInputs(getOperation(), successor);
519}
520
521//===----------------------------------------------------------------------===//
522// RegionBranchTerminatorOpInterface
523//===----------------------------------------------------------------------===//
524
526TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
527 // `acc.terminator` does not forward operands.
// An empty range (start 0, length 0) is returned for every successor.
528 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
529}
530
531//===----------------------------------------------------------------------===//
532// device_type support helpers
533//===----------------------------------------------------------------------===//
534
535static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
536 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
537}
538
539static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
540 mlir::acc::DeviceType deviceType) {
541 if (!hasDeviceTypeValues(arrayAttr))
542 return false;
543
544 for (auto attr : *arrayAttr) {
545 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
546 if (deviceTypeAttr.getValue() == deviceType)
547 return true;
548 }
549
550 return false;
551}
552
554 std::optional<mlir::ArrayAttr> deviceTypes) {
// Prints the device types as `[a, b, ...]`; prints nothing at all when the
// attribute is absent, null or empty.
555 if (!hasDeviceTypeValues(deviceTypes))
556 return;
557
558 p << "[";
559 llvm::interleaveComma(*deviceTypes, p,
560 [&](mlir::Attribute attr) { p << attr; });
561 p << "]";
562}
563
564static std::optional<unsigned> findSegment(ArrayAttr segments,
565 mlir::acc::DeviceType deviceType) {
566 unsigned segmentIdx = 0;
567 for (auto attr : segments) {
568 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
569 if (deviceTypeAttr.getValue() == deviceType)
570 return std::make_optional(segmentIdx);
571 ++segmentIdx;
572 }
573 return std::nullopt;
574}
575
577getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
579 std::optional<llvm::ArrayRef<int32_t>> segments,
580 mlir::acc::DeviceType deviceType) {
// Returns the slice of `range` belonging to `deviceType`: skip the operand
// counts of all earlier segments, then take this segment's count. Returns an
// empty slice when there is no device-type array or no matching entry.
// NOTE(review): `segments` is dereferenced without a presence check; this
// assumes it is always set whenever `arrayAttr` is.
581 if (!arrayAttr)
582 return range.take_front(0);
583 if (auto pos = findSegment(*arrayAttr, deviceType)) {
584 int32_t nbOperandsBefore = 0;
585 for (unsigned i = 0; i < *pos; ++i)
586 nbOperandsBefore += (*segments)[i];
587 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
588 }
589 return range.take_front(0);
590}
591
592static mlir::Value
593getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
595 std::optional<llvm::ArrayRef<int32_t>> segments,
596 std::optional<mlir::ArrayAttr> hasWaitDevnum,
597 mlir::acc::DeviceType deviceType) {
598 if (!hasDeviceTypeValues(deviceTypeAttr))
599 return {};
600 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
601 if (hasWaitDevnum->getValue()[*pos])
602 return getValuesFromSegments(deviceTypeAttr, operands, segments,
603 deviceType)
604 .front();
605 return {};
606}
607
609getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
611 std::optional<llvm::ArrayRef<int32_t>> segments,
612 std::optional<mlir::ArrayAttr> hasWaitDevnum,
613 mlir::acc::DeviceType deviceType) {
614 auto range =
615 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
616 if (range.empty())
617 return range;
618 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
619 if (hasWaitDevnum && *hasWaitDevnum) {
620 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
621 if (boolAttr.getValue())
622 return range.drop_front(1); // first value is devnum
623 }
624 }
625 return range;
626}
627
628template <typename Op>
629static LogicalResult checkWaitAndAsyncConflict(Op op) {
630 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
631 ++dtypeInt) {
632 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
633
634 // The asyncOnly attribute represent the async clause without value.
635 // Therefore the attribute and operand cannot appear at the same time.
636 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
637 op.hasAsyncOnly(dtype))
638 return op.emitError(
639 "asyncOnly attribute cannot appear with asyncOperand");
640
641 // The wait attribute represent the wait clause without values. Therefore
642 // the attribute and operands cannot appear at the same time.
643 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
644 op.hasWaitOnly(dtype))
645 return op.emitError("wait attribute cannot appear with waitOperands");
646 }
647 return success();
648}
649
650template <typename Op>
651static LogicalResult checkVarAndVarType(Op op) {
652 if (!op.getVar())
653 return op.emitError("must have var operand");
654
655 // A variable must have a type that is either pointer-like or mappable.
656 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
657 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
658 return op.emitError("var must be mappable or pointer-like");
659
660 // When it is a pointer-like type, the varType must capture the target type.
661 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
662 op.getVarType() == op.getVar().getType())
663 return op.emitError("varType must capture the element type of var");
664
665 return success();
666}
667
668template <typename Op>
669static LogicalResult checkVarAndAccVar(Op op) {
670 if (op.getVar().getType() != op.getAccVar().getType())
671 return op.emitError("input and output types must match");
672
673 return success();
674}
675
676template <typename Op>
677static LogicalResult checkNoModifier(Op op) {
678 if (op.getModifiers() != acc::DataClauseModifier::none)
679 return op.emitError("no data clause modifiers are allowed");
680 return success();
681}
682
683template <typename Op>
684static LogicalResult
685checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
686 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
687 return op.emitError(
688 "invalid data clause modifiers: " +
689 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
690
691 return success();
692}
693
/// Verifies the recipe reference of a recipe-based data op (used here by the
/// private, firstprivate and reduction verifiers). Vars of mappable type need
/// no recipe, except reductions. Otherwise `recipe` must be set and `decl`
/// must be non-null; `operandName` is only used in diagnostics.
/// NOTE(review): `decl` is presumably the `RecipeOpT` symbol resolved from
/// `operandRecipe`; its initializer is not shown here — confirm.
694template <typename OpT, typename RecipeOpT>
695static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
696 // Mappable types do not need a recipe because it is possible to generate one
697 // from its API. Reject reductions though because no API is available for them
698 // at this time.
699 if (mlir::acc::isMappableType(op.getVar().getType()) &&
700 !std::is_same_v<OpT, acc::ReductionOp>)
701 return success();
702
703 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
704 if (!operandRecipe)
705 return op->emitOpError() << "recipe expected for " << operandName;
706
707 auto decl =
709 if (!decl)
710 return op->emitOpError()
711 << "expected symbol reference " << operandRecipe << " to point to a "
712 << operandName << " declaration";
713 return success();
714}
715
/// Parses `varPtr(<operand>` or `var(<operand>`. The closing `)` is not
/// consumed here; parseVarPtrType consumes it after the type.
716static ParseResult parseVar(mlir::OpAsmParser &parser,
718 // Either `var` or `varPtr` keyword is required.
719 if (failed(parser.parseOptionalKeyword("varPtr"))) {
720 if (failed(parser.parseKeyword("var")))
721 return failure();
722 }
723 if (failed(parser.parseLParen()))
724 return failure();
725 if (failed(parser.parseOperand(var)))
726 return failure();
727
728 return success();
729}
730
732 mlir::Value var) {
// Counterpart of parseVar: prints `varPtr(` for pointer-like vars, `var(`
// otherwise, then the operand. The closing `)` is printed by printVarPtrType.
733 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
734 p << "varPtr(";
735 else
736 p << "var(";
737 p.printOperand(var);
738}
739
/// Parses `accPtr(<operand> : <type>)` or `accVar(<operand> : <type>)`,
/// storing the operand and its type in `accVarType`.
740static ParseResult parseAccVar(mlir::OpAsmParser &parser,
742 mlir::Type &accVarType) {
743 // Either `accVar` or `accPtr` keyword is required.
744 if (failed(parser.parseOptionalKeyword("accPtr"))) {
745 if (failed(parser.parseKeyword("accVar")))
746 return failure();
747 }
748 if (failed(parser.parseLParen()))
749 return failure();
750 if (failed(parser.parseOperand(var)))
751 return failure();
752 if (failed(parser.parseColon()))
753 return failure();
754 if (failed(parser.parseType(accVarType)))
755 return failure();
756 if (failed(parser.parseRParen()))
757 return failure();
758
759 return success();
760}
761
763 mlir::Value accVar, mlir::Type accVarType) {
// Counterpart of parseAccVar: `accPtr(` for pointer-like values, `accVar(`
// otherwise, followed by `<operand> : <type>)`.
764 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
765 p << "accPtr(";
766 else
767 p << "accVar(";
768 p.printOperand(accVar);
769 p << " : ";
770 p.printType(accVarType);
771 p << ")";
772}
773
774static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
775 mlir::Type &varPtrType,
776 mlir::TypeAttr &varTypeAttr) {
777 if (failed(parser.parseType(varPtrType)))
778 return failure();
779 if (failed(parser.parseRParen()))
780 return failure();
781
782 if (succeeded(parser.parseOptionalKeyword("varType"))) {
783 if (failed(parser.parseLParen()))
784 return failure();
785 mlir::Type varType;
786 if (failed(parser.parseType(varType)))
787 return failure();
788 varTypeAttr = mlir::TypeAttr::get(varType);
789 if (failed(parser.parseRParen()))
790 return failure();
791 } else {
792 // Set `varType` from the element type of the type of `varPtr`.
793 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
794 varTypeAttr = mlir::TypeAttr::get(
795 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
796 else
797 varTypeAttr = mlir::TypeAttr::get(varPtrType);
798 }
799
800 return success();
801}
802
804 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
// Counterpart of parseVarPtrType: prints `<varPtrType>)` and, only when it is
// not the default the parser would derive, ` varType(<type>)`.
805 p.printType(varPtrType);
806 p << ")";
807
808 // Print the `varType` only if it differs from the element type of
809 // `varPtr`'s type.
810 mlir::Type varType = varTypeAttr.getValue();
811 mlir::Type typeToCheckAgainst =
812 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
813 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
814 : varPtrType;
815 if (typeToCheckAgainst != varType) {
816 p << " varType(";
817 p.printType(varType);
818 p << ")";
819 }
820}
821
822static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
823 mlir::SymbolRefAttr &recipeAttr) {
824 if (failed(parser.parseAttribute(recipeAttr)))
825 return failure();
826 return success();
827}
828
830 mlir::SymbolRefAttr recipeAttr) {
// Counterpart of parseRecipeSym: prints the symbol reference attribute as-is.
831 p << recipeAttr;
832}
833
834//===----------------------------------------------------------------------===//
835// DataBoundsOp
836//===----------------------------------------------------------------------===//
837LogicalResult acc::DataBoundsOp::verify() {
838 auto extent = getExtent();
839 auto upperbound = getUpperbound();
840 if (!extent && !upperbound)
841 return emitError("expected extent or upperbound.");
842 return success();
843}
844
845//===----------------------------------------------------------------------===//
846// PrivateOp
847//===----------------------------------------------------------------------===//
/// Verifies acc.private, in order:
/// - the data clause is `acc_private`;
/// - the var/varType are valid;
/// - no modifiers are present;
/// - the recipe reference is valid (checkRecipe call; its arguments are not
///   shown here).
848LogicalResult acc::PrivateOp::verify() {
849 if (getDataClause() != acc::DataClause::acc_private)
850 return emitError(
851 "data clause associated with private operation must match its intent");
852 if (failed(checkVarAndVarType(*this)))
853 return failure();
854 if (failed(checkNoModifier(*this)))
855 return failure();
856 if (failed(
858 return failure();
859 return success();
860}
861
862//===----------------------------------------------------------------------===//
863// FirstprivateOp
864//===----------------------------------------------------------------------===//
/// Verifies acc.firstprivate, in order:
/// - the data clause is `acc_firstprivate`;
/// - the var/varType are valid;
/// - no modifiers are present;
/// - the "firstprivate" recipe reference is valid.
865LogicalResult acc::FirstprivateOp::verify() {
866 if (getDataClause() != acc::DataClause::acc_firstprivate)
867 return emitError("data clause associated with firstprivate operation must "
868 "match its intent");
869 if (failed(checkVarAndVarType(*this)))
870 return failure();
871 if (failed(checkNoModifier(*this)))
872 return failure();
874 *this, "firstprivate")))
875 return failure();
876 return success();
877}
878
879//===----------------------------------------------------------------------===//
880// FirstprivateMapInitialOp
881//===----------------------------------------------------------------------===//
882LogicalResult acc::FirstprivateMapInitialOp::verify() {
883 if (getDataClause() != acc::DataClause::acc_firstprivate)
884 return emitError("data clause associated with firstprivate operation must "
885 "match its intent");
886 if (failed(checkVarAndVarType(*this)))
887 return failure();
888 if (failed(checkNoModifier(*this)))
889 return failure();
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// ReductionOp
895//===----------------------------------------------------------------------===//
/// Verifies acc.reduction, in order:
/// - the data clause is `acc_reduction`;
/// - the var/varType are valid;
/// - no modifiers are present;
/// - the "reduction" recipe reference is valid. checkRecipe requires it even
///   for mappable types.
896LogicalResult acc::ReductionOp::verify() {
897 if (getDataClause() != acc::DataClause::acc_reduction)
898 return emitError("data clause associated with reduction operation must "
899 "match its intent");
900 if (failed(checkVarAndVarType(*this)))
901 return failure();
902 if (failed(checkNoModifier(*this)))
903 return failure();
905 *this, "reduction")))
906 return failure();
907 return success();
908}
909
910//===----------------------------------------------------------------------===//
911// DevicePtrOp
912//===----------------------------------------------------------------------===//
913LogicalResult acc::DevicePtrOp::verify() {
914 if (getDataClause() != acc::DataClause::acc_deviceptr)
915 return emitError("data clause associated with deviceptr operation must "
916 "match its intent");
917 if (failed(checkVarAndVarType(*this)))
918 return failure();
919 if (failed(checkVarAndAccVar(*this)))
920 return failure();
921 if (failed(checkNoModifier(*this)))
922 return failure();
923 return success();
924}
925
926//===----------------------------------------------------------------------===//
927// PresentOp
928//===----------------------------------------------------------------------===//
929LogicalResult acc::PresentOp::verify() {
930 if (getDataClause() != acc::DataClause::acc_present)
931 return emitError(
932 "data clause associated with present operation must match its intent");
933 if (failed(checkVarAndVarType(*this)))
934 return failure();
935 if (failed(checkVarAndAccVar(*this)))
936 return failure();
937 if (failed(checkNoModifier(*this)))
938 return failure();
939 return success();
940}
941
942//===----------------------------------------------------------------------===//
943// CopyinOp
944//===----------------------------------------------------------------------===//
945LogicalResult acc::CopyinOp::verify() {
946 // Test for all clauses this operation can be decomposed from:
947 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
948 getDataClause() != acc::DataClause::acc_copyin_readonly &&
949 getDataClause() != acc::DataClause::acc_copy &&
950 getDataClause() != acc::DataClause::acc_reduction)
951 return emitError(
952 "data clause associated with copyin operation must match its intent"
953 " or specify original clause this operation was decomposed from");
954 if (failed(checkVarAndVarType(*this)))
955 return failure();
956 if (failed(checkVarAndAccVar(*this)))
957 return failure();
958 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
959 acc::DataClauseModifier::always |
960 acc::DataClauseModifier::capture)))
961 return failure();
962 return success();
963}
964
965bool acc::CopyinOp::isCopyinReadonly() {
966 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
967 acc::bitEnumContainsAny(getModifiers(),
968 acc::DataClauseModifier::readonly);
969}
970
971//===----------------------------------------------------------------------===//
972// CreateOp
973//===----------------------------------------------------------------------===//
974LogicalResult acc::CreateOp::verify() {
975 // Test for all clauses this operation can be decomposed from:
976 if (getDataClause() != acc::DataClause::acc_create &&
977 getDataClause() != acc::DataClause::acc_create_zero &&
978 getDataClause() != acc::DataClause::acc_copyout &&
979 getDataClause() != acc::DataClause::acc_copyout_zero)
980 return emitError(
981 "data clause associated with create operation must match its intent"
982 " or specify original clause this operation was decomposed from");
983 if (failed(checkVarAndVarType(*this)))
984 return failure();
985 if (failed(checkVarAndAccVar(*this)))
986 return failure();
987 // this op is the entry part of copyout, so it also needs to allow all
988 // modifiers allowed on copyout.
989 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
990 acc::DataClauseModifier::always |
991 acc::DataClauseModifier::capture)))
992 return failure();
993 return success();
994}
995
996bool acc::CreateOp::isCreateZero() {
997 // The zero modifier is encoded in the data clause.
998 return getDataClause() == acc::DataClause::acc_create_zero ||
999 getDataClause() == acc::DataClause::acc_copyout_zero ||
1000 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1001}
1002
1003//===----------------------------------------------------------------------===//
1004// NoCreateOp
1005//===----------------------------------------------------------------------===//
1006LogicalResult acc::NoCreateOp::verify() {
1007 if (getDataClause() != acc::DataClause::acc_no_create)
1008 return emitError("data clause associated with no_create operation must "
1009 "match its intent");
1010 if (failed(checkVarAndVarType(*this)))
1011 return failure();
1012 if (failed(checkVarAndAccVar(*this)))
1013 return failure();
1014 if (failed(checkNoModifier(*this)))
1015 return failure();
1016 return success();
1017}
1018
1019//===----------------------------------------------------------------------===//
1020// AttachOp
1021//===----------------------------------------------------------------------===//
1022LogicalResult acc::AttachOp::verify() {
1023 if (getDataClause() != acc::DataClause::acc_attach)
1024 return emitError(
1025 "data clause associated with attach operation must match its intent");
1026 if (failed(checkVarAndVarType(*this)))
1027 return failure();
1028 if (failed(checkVarAndAccVar(*this)))
1029 return failure();
1030 if (failed(checkNoModifier(*this)))
1031 return failure();
1032 return success();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// DeclareDeviceResidentOp
1037//===----------------------------------------------------------------------===//
1038
1039LogicalResult acc::DeclareDeviceResidentOp::verify() {
1040 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1041 return emitError("data clause associated with device_resident operation "
1042 "must match its intent");
1043 if (failed(checkVarAndVarType(*this)))
1044 return failure();
1045 if (failed(checkVarAndAccVar(*this)))
1046 return failure();
1047 if (failed(checkNoModifier(*this)))
1048 return failure();
1049 return success();
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// DeclareLinkOp
1054//===----------------------------------------------------------------------===//
1055
1056LogicalResult acc::DeclareLinkOp::verify() {
1057 if (getDataClause() != acc::DataClause::acc_declare_link)
1058 return emitError(
1059 "data clause associated with link operation must match its intent");
1060 if (failed(checkVarAndVarType(*this)))
1061 return failure();
1062 if (failed(checkVarAndAccVar(*this)))
1063 return failure();
1064 if (failed(checkNoModifier(*this)))
1065 return failure();
1066 return success();
1067}
1068
1069//===----------------------------------------------------------------------===//
1070// CopyoutOp
1071//===----------------------------------------------------------------------===//
1072LogicalResult acc::CopyoutOp::verify() {
1073 // Test for all clauses this operation can be decomposed from:
1074 if (getDataClause() != acc::DataClause::acc_copyout &&
1075 getDataClause() != acc::DataClause::acc_copyout_zero &&
1076 getDataClause() != acc::DataClause::acc_copy &&
1077 getDataClause() != acc::DataClause::acc_reduction)
1078 return emitError(
1079 "data clause associated with copyout operation must match its intent"
1080 " or specify original clause this operation was decomposed from");
1081 if (!getVar() || !getAccVar())
1082 return emitError("must have both host and device pointers");
1083 if (failed(checkVarAndVarType(*this)))
1084 return failure();
1085 if (failed(checkVarAndAccVar(*this)))
1086 return failure();
1087 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1088 acc::DataClauseModifier::always |
1089 acc::DataClauseModifier::capture)))
1090 return failure();
1091 return success();
1092}
1093
1094bool acc::CopyoutOp::isCopyoutZero() {
1095 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1096 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1097}
1098
1099//===----------------------------------------------------------------------===//
1100// DeleteOp
1101//===----------------------------------------------------------------------===//
1102LogicalResult acc::DeleteOp::verify() {
1103 // Test for all clauses this operation can be decomposed from:
1104 if (getDataClause() != acc::DataClause::acc_delete &&
1105 getDataClause() != acc::DataClause::acc_create &&
1106 getDataClause() != acc::DataClause::acc_create_zero &&
1107 getDataClause() != acc::DataClause::acc_copyin &&
1108 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1109 getDataClause() != acc::DataClause::acc_present &&
1110 getDataClause() != acc::DataClause::acc_no_create &&
1111 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1112 getDataClause() != acc::DataClause::acc_declare_link)
1113 return emitError(
1114 "data clause associated with delete operation must match its intent"
1115 " or specify original clause this operation was decomposed from");
1116 if (!getAccVar())
1117 return emitError("must have device pointer");
1118 // This op is the exit part of copyin and create - thus allow all modifiers
1119 // allowed on either case.
1120 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1121 acc::DataClauseModifier::readonly |
1122 acc::DataClauseModifier::always |
1123 acc::DataClauseModifier::capture)))
1124 return failure();
1125 return success();
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// DetachOp
1130//===----------------------------------------------------------------------===//
1131LogicalResult acc::DetachOp::verify() {
1132 // Test for all clauses this operation can be decomposed from:
1133 if (getDataClause() != acc::DataClause::acc_detach &&
1134 getDataClause() != acc::DataClause::acc_attach)
1135 return emitError(
1136 "data clause associated with detach operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1138 if (!getAccVar())
1139 return emitError("must have device pointer");
1140 if (failed(checkNoModifier(*this)))
1141 return failure();
1142 return success();
1143}
1144
1145//===----------------------------------------------------------------------===//
1146// HostOp
1147//===----------------------------------------------------------------------===//
1148LogicalResult acc::UpdateHostOp::verify() {
1149 // Test for all clauses this operation can be decomposed from:
1150 if (getDataClause() != acc::DataClause::acc_update_host &&
1151 getDataClause() != acc::DataClause::acc_update_self)
1152 return emitError(
1153 "data clause associated with host operation must match its intent"
1154 " or specify original clause this operation was decomposed from");
1155 if (!getVar() || !getAccVar())
1156 return emitError("must have both host and device pointers");
1157 if (failed(checkVarAndVarType(*this)))
1158 return failure();
1159 if (failed(checkVarAndAccVar(*this)))
1160 return failure();
1161 if (failed(checkNoModifier(*this)))
1162 return failure();
1163 return success();
1164}
1165
1166//===----------------------------------------------------------------------===//
1167// DeviceOp
1168//===----------------------------------------------------------------------===//
1169LogicalResult acc::UpdateDeviceOp::verify() {
1170 // Test for all clauses this operation can be decomposed from:
1171 if (getDataClause() != acc::DataClause::acc_update_device)
1172 return emitError(
1173 "data clause associated with device operation must match its intent"
1174 " or specify original clause this operation was decomposed from");
1175 if (failed(checkVarAndVarType(*this)))
1176 return failure();
1177 if (failed(checkVarAndAccVar(*this)))
1178 return failure();
1179 if (failed(checkNoModifier(*this)))
1180 return failure();
1181 return success();
1182}
1183
1184//===----------------------------------------------------------------------===//
1185// UseDeviceOp
1186//===----------------------------------------------------------------------===//
1187LogicalResult acc::UseDeviceOp::verify() {
1188 // Test for all clauses this operation can be decomposed from:
1189 if (getDataClause() != acc::DataClause::acc_use_device)
1190 return emitError(
1191 "data clause associated with use_device operation must match its intent"
1192 " or specify original clause this operation was decomposed from");
1193 if (failed(checkVarAndVarType(*this)))
1194 return failure();
1195 if (failed(checkVarAndAccVar(*this)))
1196 return failure();
1197 if (failed(checkNoModifier(*this)))
1198 return failure();
1199 return success();
1200}
1201
1202//===----------------------------------------------------------------------===//
1203// CacheOp
1204//===----------------------------------------------------------------------===//
1205LogicalResult acc::CacheOp::verify() {
1206 // Test for all clauses this operation can be decomposed from:
1207 if (getDataClause() != acc::DataClause::acc_cache &&
1208 getDataClause() != acc::DataClause::acc_cache_readonly)
1209 return emitError(
1210 "data clause associated with cache operation must match its intent"
1211 " or specify original clause this operation was decomposed from");
1212 if (failed(checkVarAndVarType(*this)))
1213 return failure();
1214 if (failed(checkVarAndAccVar(*this)))
1215 return failure();
1216 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1217 return failure();
1218 return success();
1219}
1220
1221bool acc::CacheOp::isCacheReadonly() {
1222 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1223 acc::bitEnumContainsAny(getModifiers(),
1224 acc::DataClauseModifier::readonly);
1225}
1226
1227template <typename StructureOp>
1228static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1229 unsigned nRegions = 1) {
1230
1232 for (unsigned i = 0; i < nRegions; ++i)
1233 regions.push_back(state.addRegion());
1234
1235 for (Region *region : regions)
1236 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1237 return failure();
1238
1239 return success();
1240}
1241
// NOTE(review): the signature line of this function (original line 1242) is
// missing from this listing. The body returns whether `op` is any of the
// operation types listed by the ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS macro.
1243 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1244}
1245
1246namespace {
1247/// Pattern to remove operation without region that have constant false `ifCond`
1248/// and remove the condition from the operation if the `ifCond` is a true
1249/// constant.
1250template <typename OpTy>
1251struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1252 using OpRewritePattern<OpTy>::OpRewritePattern;
1253
1254 LogicalResult matchAndRewrite(OpTy op,
1255 PatternRewriter &rewriter) const override {
1256 // Early return if there is no condition.
1257 Value ifCond = op.getIfCond();
1258 if (!ifCond)
1259 return failure();
1260
1261 IntegerAttr constAttr;
1262 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1263 return failure();
1264 if (constAttr.getInt())
1265 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1266 else
1267 rewriter.eraseOp(op);
1268
1269 return success();
1270 }
1271};
1272
1273/// Replaces the given op with the contents of the given single-block region,
1274/// using the operands of the block terminator to replace operation results.
1275static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1276 Region &region, ValueRange blockArgs = {}) {
1277 assert(region.hasOneBlock() && "expected single-block region");
1278 Block *block = &region.front();
1279 Operation *terminator = block->getTerminator();
1280 ValueRange results = terminator->getOperands();
1281 rewriter.inlineBlockBefore(block, op, blockArgs);
1282 rewriter.replaceOp(op, results);
1283 rewriter.eraseOp(terminator);
1284}
1285
1286/// Pattern to remove operation with region that have constant false `ifCond`
1287/// and remove the condition from the operation if the `ifCond` is constant
1288/// true.
1289template <typename OpTy>
1290struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1291 using OpRewritePattern<OpTy>::OpRewritePattern;
1292
1293 LogicalResult matchAndRewrite(OpTy op,
1294 PatternRewriter &rewriter) const override {
1295 // Early return if there is no condition.
1296 Value ifCond = op.getIfCond();
1297 if (!ifCond)
1298 return failure();
1299
1300 IntegerAttr constAttr;
1301 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1302 return failure();
1303 if (constAttr.getInt())
1304 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1305 else
1306 replaceOpWithRegion(rewriter, op, op.getRegion());
1307
1308 return success();
1309 }
1310};
1311
1312/// Remove empty acc.kernel_environment operations. If the operation has wait
1313/// operands, create a acc.wait operation to preserve synchronization.
1314struct RemoveEmptyKernelEnvironment
1315 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1316 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1317
1318 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1319 PatternRewriter &rewriter) const override {
1320 assert(op->getNumRegions() == 1 && "expected op to have one region");
1321
1322 Block &block = op.getRegion().front();
1323 if (!block.empty())
1324 return failure();
1325
1326 // Conservatively disable canonicalization of empty acc.kernel_environment
1327 // operations if the wait operands in the kernel_environment cannot be fully
1328 // represented by acc.wait operation.
1329
1330 // Disable canonicalization if device type is not the default
1331 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1332 for (auto attr : deviceTypeAttr) {
1333 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1334 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1335 return failure();
1336 }
1337 }
1338 }
1339
1340 // Disable canonicalization if any wait segment has a devnum
1341 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1342 for (auto attr : hasDevnumAttr) {
1343 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1344 if (boolAttr.getValue())
1345 return failure();
1346 }
1347 }
1348 }
1349
1350 // Disable canonicalization if there are multiple wait segments
1351 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1352 if (segmentsAttr.size() > 1)
1353 return failure();
1354 }
1355
1356 // Remove empty kernel environment.
1357 // Preserve synchronization by creating acc.wait operation if needed.
1358 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1359 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1360 /*asyncOperand=*/Value(),
1361 /*waitDevnum=*/Value(),
1362 /*async=*/nullptr,
1363 /*ifCond=*/Value());
1364 else
1365 rewriter.eraseOp(op);
1366
1367 return success();
1368 }
1369};
1370
1371//===----------------------------------------------------------------------===//
1372// Recipe Region Helpers
1373//===----------------------------------------------------------------------===//
1374
1375/// Create and populate an init region for privatization recipes.
1376/// Returns success if the region is populated, failure otherwise.
1377/// Sets needsFree to indicate if the allocated memory requires deallocation.
1378static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1379 Region &initRegion, Type varType,
1380 StringRef varName, ValueRange bounds,
1381 bool &needsFree) {
1382 // Create init block with arguments: original value + bounds
1383 SmallVector<Type> argTypes{varType};
1384 SmallVector<Location> argLocs{loc};
1385 for (Value bound : bounds) {
1386 argTypes.push_back(bound.getType());
1387 argLocs.push_back(loc);
1388 }
1389
1390 Block *initBlock = builder.createBlock(&initRegion);
1391 initBlock->addArguments(argTypes, argLocs);
1392 builder.setInsertionPointToStart(initBlock);
1393
1394 Value privatizedValue;
1395
1396 // Get the block argument that represents the original variable
1397 Value blockArgVar = initBlock->getArgument(0);
1398
1399 // Generate init region body based on variable type
1400 if (isa<MappableType>(varType)) {
1401 auto mappableTy = cast<MappableType>(varType);
1402 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1403 privatizedValue = mappableTy.generatePrivateInit(
1404 builder, loc, typedVar, varName, bounds, {}, needsFree);
1405 if (!privatizedValue)
1406 return failure();
1407 } else {
1408 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1409 auto pointerLikeTy = cast<PointerLikeType>(varType);
1410 // Use PointerLikeType's allocation API with the block argument
1411 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1412 blockArgVar, needsFree);
1413 if (!privatizedValue)
1414 return failure();
1415 }
1416
1417 // Add yield operation to init block
1418 acc::YieldOp::create(builder, loc, privatizedValue);
1419
1420 return success();
1421}
1422
1423/// Create and populate a copy region for firstprivate recipes.
1424/// Returns success if the region is populated, failure otherwise.
1425/// TODO: Handle MappableType - it does not yet have a copy API.
1426static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1427 Region &copyRegion, Type varType,
1428 ValueRange bounds) {
1429 // Create copy block with arguments: original value + privatized value +
1430 // bounds
1431 SmallVector<Type> copyArgTypes{varType, varType};
1432 SmallVector<Location> copyArgLocs{loc, loc};
1433 for (Value bound : bounds) {
1434 copyArgTypes.push_back(bound.getType());
1435 copyArgLocs.push_back(loc);
1436 }
1437
1438 Block *copyBlock = builder.createBlock(&copyRegion);
1439 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1440 builder.setInsertionPointToStart(copyBlock);
1441
1442 bool isMappable = isa<MappableType>(varType);
1443 bool isPointerLike = isa<PointerLikeType>(varType);
1444 // TODO: Handle MappableType - it does not yet have a copy API.
1445 // Otherwise, for now just fallback to pointer-like behavior.
1446 if (isMappable && !isPointerLike)
1447 return failure();
1448
1449 // Generate copy region body based on variable type
1450 if (isPointerLike) {
1451 auto pointerLikeTy = cast<PointerLikeType>(varType);
1452 Value originalArg = copyBlock->getArgument(0);
1453 Value privatizedArg = copyBlock->getArgument(1);
1454
1455 // Generate copy operation using PointerLikeType interface
1456 if (!pointerLikeTy.genCopy(
1457 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1458 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1459 return failure();
1460 }
1461
1462 // Add terminator to copy block
1463 acc::TerminatorOp::create(builder, loc);
1464
1465 return success();
1466}
1467
1468/// Create and populate a destroy region for privatization recipes.
1469/// Returns success if the region is populated, failure otherwise.
1470static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1471 Region &destroyRegion, Type varType,
1472 Value allocRes, ValueRange bounds) {
1473 // Create destroy block with arguments: original value + privatized value +
1474 // bounds
1475 SmallVector<Type> destroyArgTypes{varType, varType};
1476 SmallVector<Location> destroyArgLocs{loc, loc};
1477 for (Value bound : bounds) {
1478 destroyArgTypes.push_back(bound.getType());
1479 destroyArgLocs.push_back(loc);
1480 }
1481
1482 Block *destroyBlock = builder.createBlock(&destroyRegion);
1483 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1484 builder.setInsertionPointToStart(destroyBlock);
1485
1486 auto varToFree =
1487 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1488 if (isa<MappableType>(varType)) {
1489 auto mappableTy = cast<MappableType>(varType);
1490 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1491 return failure();
1492 } else {
1493 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1494 auto pointerLikeTy = cast<PointerLikeType>(varType);
1495 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1496 return failure();
1497 }
1498
1499 acc::TerminatorOp::create(builder, loc);
1500 return success();
1501}
1502
1503} // namespace
1504
1505//===----------------------------------------------------------------------===//
1506// PrivateRecipeOp
1507//===----------------------------------------------------------------------===//
1508
1510 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1511 Type type, bool verifyYield, bool optional = false) {
1512 if (optional && region.empty())
1513 return success();
1514
1515 if (region.empty())
1516 return op->emitOpError() << "expects non-empty " << regionName << " region";
1517 Block &firstBlock = region.front();
1518 if (firstBlock.getNumArguments() < 1 ||
1519 firstBlock.getArgument(0).getType() != type)
1520 return op->emitOpError() << "expects " << regionName
1521 << " region first "
1522 "argument of the "
1523 << regionType << " type";
1524
1525 if (verifyYield) {
1526 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1527 if (yieldOp.getOperands().size() != 1 ||
1528 yieldOp.getOperands().getTypes()[0] != type)
1529 return op->emitOpError() << "expects " << regionName
1530 << " region to "
1531 "yield a value of the "
1532 << regionType << " type";
1533 }
1534 }
1535 return success();
1536}
1537
1538LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1539 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1540 "privatization", "init", getType(),
1541 /*verifyYield=*/false)))
1542 return failure();
1544 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1545 /*verifyYield=*/false, /*optional=*/true)))
1546 return failure();
1547 return success();
1548}
1549
1550std::optional<PrivateRecipeOp>
1551PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1552 StringRef recipeName, Type varType,
1553 StringRef varName, ValueRange bounds) {
1554 // First, validate that we can handle this variable type
1555 bool isMappable = isa<MappableType>(varType);
1556 bool isPointerLike = isa<PointerLikeType>(varType);
1557
1558 // Unsupported type
1559 if (!isMappable && !isPointerLike)
1560 return std::nullopt;
1561
1562 OpBuilder::InsertionGuard guard(builder);
1563
1564 // Create the recipe operation first so regions have proper parent context
1565 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1566
1567 // Populate the init region
1568 bool needsFree = false;
1569 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1570 varName, bounds, needsFree))) {
1571 recipe.erase();
1572 return std::nullopt;
1573 }
1574
1575 // Only create destroy region if the allocation needs deallocation
1576 if (needsFree) {
1577 // Extract the allocated value from the init block's yield operation
1578 auto yieldOp =
1579 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1580 Value allocRes = yieldOp.getOperand(0);
1581
1582 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1583 varType, allocRes, bounds))) {
1584 recipe.erase();
1585 return std::nullopt;
1586 }
1587 }
1588
1589 return recipe;
1590}
1591
1592std::optional<PrivateRecipeOp>
1593PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1594 StringRef recipeName,
1595 FirstprivateRecipeOp firstprivRecipe) {
1596 // Create the private.recipe op with the same type as the firstprivate.recipe.
1597 OpBuilder::InsertionGuard guard(builder);
1598 auto varType = firstprivRecipe.getType();
1599 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1600
1601 // Clone the init region
1602 IRMapping mapping;
1603 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1604
1605 // Clone destroy region if the firstprivate.recipe has one.
1606 if (!firstprivRecipe.getDestroyRegion().empty()) {
1607 IRMapping mapping;
1608 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1609 mapping);
1610 }
1611 return recipe;
1612}
1613
1614//===----------------------------------------------------------------------===//
1615// FirstprivateRecipeOp
1616//===----------------------------------------------------------------------===//
1617
1618LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1619 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1620 "privatization", "init", getType(),
1621 /*verifyYield=*/false)))
1622 return failure();
1623
1624 if (getCopyRegion().empty())
1625 return emitOpError() << "expects non-empty copy region";
1626
1627 Block &firstBlock = getCopyRegion().front();
1628 if (firstBlock.getNumArguments() < 2 ||
1629 firstBlock.getArgument(0).getType() != getType())
1630 return emitOpError() << "expects copy region with two arguments of the "
1631 "privatization type";
1632
1633 if (getDestroyRegion().empty())
1634 return success();
1635
1636 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1637 "privatization", "destroy",
1638 getType(), /*verifyYield=*/false)))
1639 return failure();
1640
1641 return success();
1642}
1643
1644std::optional<FirstprivateRecipeOp>
1645FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1646 StringRef recipeName, Type varType,
1647 StringRef varName, ValueRange bounds) {
1648 // First, validate that we can handle this variable type
1649 bool isMappable = isa<MappableType>(varType);
1650 bool isPointerLike = isa<PointerLikeType>(varType);
1651
1652 // Unsupported type
1653 if (!isMappable && !isPointerLike)
1654 return std::nullopt;
1655
1656 OpBuilder::InsertionGuard guard(builder);
1657
1658 // Create the recipe operation first so regions have proper parent context
1659 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1660
1661 // Populate the init region
1662 bool needsFree = false;
1663 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1664 varName, bounds, needsFree))) {
1665 recipe.erase();
1666 return std::nullopt;
1667 }
1668
1669 // Populate the copy region
1670 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1671 bounds))) {
1672 recipe.erase();
1673 return std::nullopt;
1674 }
1675
1676 // Only create destroy region if the allocation needs deallocation
1677 if (needsFree) {
1678 // Extract the allocated value from the init block's yield operation
1679 auto yieldOp =
1680 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1681 Value allocRes = yieldOp.getOperand(0);
1682
1683 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1684 varType, allocRes, bounds))) {
1685 recipe.erase();
1686 return std::nullopt;
1687 }
1688 }
1689
1690 return recipe;
1691}
1692
1693//===----------------------------------------------------------------------===//
1694// ReductionRecipeOp
1695//===----------------------------------------------------------------------===//
1696
1697LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1698 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1699 "init", getType(),
1700 /*verifyYield=*/false)))
1701 return failure();
1702
1703 if (getCombinerRegion().empty())
1704 return emitOpError() << "expects non-empty combiner region";
1705
1706 Block &reductionBlock = getCombinerRegion().front();
1707 if (reductionBlock.getNumArguments() < 2 ||
1708 reductionBlock.getArgument(0).getType() != getType() ||
1709 reductionBlock.getArgument(1).getType() != getType())
1710 return emitOpError() << "expects combiner region with the first two "
1711 << "arguments of the reduction type";
1712
1713 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1714 if (yieldOp.getOperands().size() != 1 ||
1715 yieldOp.getOperands().getTypes()[0] != getType())
1716 return emitOpError() << "expects combiner region to yield a value "
1717 "of the reduction type";
1718 }
1719
1720 return success();
1721}
1722
1723//===----------------------------------------------------------------------===//
1724// ParallelOp
1725//===----------------------------------------------------------------------===//
1726
1727/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1728template <typename Op>
1729static LogicalResult checkDataOperands(Op op,
1730 const mlir::ValueRange &operands) {
1731 for (mlir::Value operand : operands)
1732 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1733 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1734 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1735 operand.getDefiningOp()))
1736 return op.emitError(
1737 "expect data entry/exit operation or acc.getdeviceptr "
1738 "as defining op");
1739 return success();
1740}
1741
1742template <typename OpT, typename RecipeOpT>
1743static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1744 const mlir::ValueRange &operands,
1745 llvm::StringRef operandName) {
1747 for (mlir::Value operand : operands) {
1748 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1749 return accConstructOp->emitOpError()
1750 << "expected " << operandName << " as defining op";
1751 if (!set.insert(operand).second)
1752 return accConstructOp->emitOpError()
1753 << operandName << " operand appears more than once";
1754 }
1755 return success();
1756}
1757
1758unsigned ParallelOp::getNumDataOperands() {
1759 return getReductionOperands().size() + getPrivateOperands().size() +
1760 getFirstprivateOperands().size() + getDataClauseOperands().size();
1761}
1762
1763Value ParallelOp::getDataOperand(unsigned i) {
1764 unsigned numOptional = getAsyncOperands().size();
1765 numOptional += getNumGangs().size();
1766 numOptional += getNumWorkers().size();
1767 numOptional += getVectorLength().size();
1768 numOptional += getIfCond() ? 1 : 0;
1769 numOptional += getSelfCond() ? 1 : 0;
1770 return getOperand(getWaitOperands().size() + numOptional + i);
1771}
1772
1773template <typename Op>
1774static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1775 ArrayAttr deviceTypes,
1776 llvm::StringRef keyword) {
1777 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1778 return op.emitOpError() << keyword << " operands count must match "
1779 << keyword << " device_type count";
1780 return success();
1781}
1782
1783template <typename Op>
1785 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1786 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1787 std::size_t numOperandsInSegments = 0;
1788 std::size_t nbOfSegments = 0;
1789
1790 if (segments) {
1791 for (auto segCount : segments.asArrayRef()) {
1792 if (maxInSegment != 0 && segCount > maxInSegment)
1793 return op.emitOpError() << keyword << " expects a maximum of "
1794 << maxInSegment << " values per segment";
1795 numOperandsInSegments += segCount;
1796 ++nbOfSegments;
1797 }
1798 }
1799
1800 if ((numOperandsInSegments != operands.size()) ||
1801 (!deviceTypes && !operands.empty()))
1802 return op.emitOpError()
1803 << keyword << " operand count does not match count in segments";
1804 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1805 return op.emitOpError()
1806 << keyword << " segment count does not match device_type count";
1807 return success();
1808}
1809
1810LogicalResult acc::ParallelOp::verify() {
1811 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1812 mlir::acc::PrivateRecipeOp>(
1813 *this, getPrivateOperands(), "private")))
1814 return failure();
1815 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
1816 mlir::acc::FirstprivateRecipeOp>(
1817 *this, getFirstprivateOperands(), "firstprivate")))
1818 return failure();
1819 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
1820 mlir::acc::ReductionRecipeOp>(
1821 *this, getReductionOperands(), "reduction")))
1822 return failure();
1823
1825 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1826 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1827 return failure();
1828
1830 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1831 getWaitOperandsDeviceTypeAttr(), "wait")))
1832 return failure();
1833
1834 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1835 getNumWorkersDeviceTypeAttr(),
1836 "num_workers")))
1837 return failure();
1838
1839 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1840 getVectorLengthDeviceTypeAttr(),
1841 "vector_length")))
1842 return failure();
1843
1845 getAsyncOperandsDeviceTypeAttr(),
1846 "async")))
1847 return failure();
1848
1850 return failure();
1851
1852 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1853}
1854
1855static mlir::Value
1856getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1858 mlir::acc::DeviceType deviceType) {
1859 if (!arrayAttr)
1860 return {};
1861 if (auto pos = findSegment(*arrayAttr, deviceType))
1862 return range[*pos];
1863 return {};
1864}
1865
1866bool acc::ParallelOp::hasAsyncOnly() {
1867 return hasAsyncOnly(mlir::acc::DeviceType::None);
1868}
1869
1870bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1871 return hasDeviceType(getAsyncOnly(), deviceType);
1872}
1873
1874mlir::Value acc::ParallelOp::getAsyncValue() {
1875 return getAsyncValue(mlir::acc::DeviceType::None);
1876}
1877
1878mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1880 getAsyncOperands(), deviceType);
1881}
1882
1883mlir::Value acc::ParallelOp::getNumWorkersValue() {
1884 return getNumWorkersValue(mlir::acc::DeviceType::None);
1885}
1886
1888acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1889 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1890 deviceType);
1891}
1892
1893mlir::Value acc::ParallelOp::getVectorLengthValue() {
1894 return getVectorLengthValue(mlir::acc::DeviceType::None);
1895}
1896
1898acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1899 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1900 getVectorLength(), deviceType);
1901}
1902
1903mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1904 return getNumGangsValues(mlir::acc::DeviceType::None);
1905}
1906
1908ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1909 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1910 getNumGangsSegments(), deviceType);
1911}
1912
1913bool acc::ParallelOp::hasWaitOnly() {
1914 return hasWaitOnly(mlir::acc::DeviceType::None);
1915}
1916
1917bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1918 return hasDeviceType(getWaitOnly(), deviceType);
1919}
1920
1921mlir::Operation::operand_range ParallelOp::getWaitValues() {
1922 return getWaitValues(mlir::acc::DeviceType::None);
1923}
1924
1926ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1928 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1929 getHasWaitDevnum(), deviceType);
1930}
1931
1932mlir::Value ParallelOp::getWaitDevnum() {
1933 return getWaitDevnum(mlir::acc::DeviceType::None);
1934}
1935
1936mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1937 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1938 getWaitOperandsSegments(), getHasWaitDevnum(),
1939 deviceType);
1940}
1941
1942void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1943 mlir::OperationState &odsState,
1944 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1945 mlir::ValueRange vectorLength,
1946 mlir::ValueRange asyncOperands,
1947 mlir::ValueRange waitOperands, mlir::Value ifCond,
1948 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1949 mlir::ValueRange gangPrivateOperands,
1950 mlir::ValueRange gangFirstPrivateOperands,
1951 mlir::ValueRange dataClauseOperands) {
1952 ParallelOp::build(
1953 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1954 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1955 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1956 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1957 /*numGangsDeviceType=*/nullptr, numWorkers,
1958 /*numWorkersDeviceType=*/nullptr, vectorLength,
1959 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1960 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
1961 gangFirstPrivateOperands, dataClauseOperands,
1962 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1963}
1964
1965void acc::ParallelOp::addNumWorkersOperand(
1966 MLIRContext *context, mlir::Value newValue,
1967 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1968 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1969 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1970 getNumWorkersMutable()));
1971}
1972void acc::ParallelOp::addVectorLengthOperand(
1973 MLIRContext *context, mlir::Value newValue,
1974 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1975 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1976 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1977 getVectorLengthMutable()));
1978}
1979
1980void acc::ParallelOp::addAsyncOnly(
1981 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1982 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1983 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1984}
1985
1986void acc::ParallelOp::addAsyncOperand(
1987 MLIRContext *context, mlir::Value newValue,
1988 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1989 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1990 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1991 getAsyncOperandsMutable()));
1992}
1993
1994void acc::ParallelOp::addNumGangsOperands(
1995 MLIRContext *context, mlir::ValueRange newValues,
1996 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1998 if (getNumGangsSegments())
1999 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2000
2001 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2002 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2003 getNumGangsMutable(), segments));
2004
2005 setNumGangsSegments(segments);
2006}
2007void acc::ParallelOp::addWaitOnly(
2008 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2009 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2010 effectiveDeviceTypes));
2011}
2012void acc::ParallelOp::addWaitOperands(
2013 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2014 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2015
2017 if (getWaitOperandsSegments())
2018 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2019
2020 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2021 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2022 getWaitOperandsMutable(), segments));
2023 setWaitOperandsSegments(segments);
2024
2026 if (getHasWaitDevnumAttr())
2027 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2028 hasDevnums.insert(
2029 hasDevnums.end(),
2030 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2031 mlir::BoolAttr::get(context, hasDevnum));
2032 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2033}
2034
2035void acc::ParallelOp::addPrivatization(MLIRContext *context,
2036 mlir::acc::PrivateOp op,
2037 mlir::acc::PrivateRecipeOp recipe) {
2038 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2039 getPrivateOperandsMutable().append(op.getResult());
2040}
2041
2042void acc::ParallelOp::addFirstPrivatization(
2043 MLIRContext *context, mlir::acc::FirstprivateOp op,
2044 mlir::acc::FirstprivateRecipeOp recipe) {
2045 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2046 getFirstprivateOperandsMutable().append(op.getResult());
2047}
2048
2049void acc::ParallelOp::addReduction(MLIRContext *context,
2050 mlir::acc::ReductionOp op,
2051 mlir::acc::ReductionRecipeOp recipe) {
2052 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2053 getReductionOperandsMutable().append(op.getResult());
2054}
2055
2056static ParseResult parseNumGangs(
2057 mlir::OpAsmParser &parser,
2059 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2060 mlir::DenseI32ArrayAttr &segments) {
2063
2064 do {
2065 if (failed(parser.parseLBrace()))
2066 return failure();
2067
2068 int32_t crtOperandsSize = operands.size();
2069 if (failed(parser.parseCommaSeparatedList(
2071 if (parser.parseOperand(operands.emplace_back()) ||
2072 parser.parseColonType(types.emplace_back()))
2073 return failure();
2074 return success();
2075 })))
2076 return failure();
2077 seg.push_back(operands.size() - crtOperandsSize);
2078
2079 if (failed(parser.parseRBrace()))
2080 return failure();
2081
2082 if (succeeded(parser.parseOptionalLSquare())) {
2083 if (parser.parseAttribute(attributes.emplace_back()) ||
2084 parser.parseRSquare())
2085 return failure();
2086 } else {
2087 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2088 parser.getContext(), mlir::acc::DeviceType::None));
2089 }
2090 } while (succeeded(parser.parseOptionalComma()));
2091
2092 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2093 attributes.end());
2094 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2095 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2096
2097 return success();
2098}
2099
2101 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2102 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2103 p << " [" << attr << "]";
2104}
2105
2107 mlir::OperandRange operands, mlir::TypeRange types,
2108 std::optional<mlir::ArrayAttr> deviceTypes,
2109 std::optional<mlir::DenseI32ArrayAttr> segments) {
2110 unsigned opIdx = 0;
2111 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2112 p << "{";
2113 llvm::interleaveComma(
2114 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2115 p << operands[opIdx] << " : " << operands[opIdx].getType();
2116 ++opIdx;
2117 });
2118 p << "}";
2119 printSingleDeviceType(p, it.value());
2120 });
2121}
2122
2124 mlir::OpAsmParser &parser,
2126 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2127 mlir::DenseI32ArrayAttr &segments) {
2130
2131 do {
2132 if (failed(parser.parseLBrace()))
2133 return failure();
2134
2135 int32_t crtOperandsSize = operands.size();
2136
2137 if (failed(parser.parseCommaSeparatedList(
2139 if (parser.parseOperand(operands.emplace_back()) ||
2140 parser.parseColonType(types.emplace_back()))
2141 return failure();
2142 return success();
2143 })))
2144 return failure();
2145
2146 seg.push_back(operands.size() - crtOperandsSize);
2147
2148 if (failed(parser.parseRBrace()))
2149 return failure();
2150
2151 if (succeeded(parser.parseOptionalLSquare())) {
2152 if (parser.parseAttribute(attributes.emplace_back()) ||
2153 parser.parseRSquare())
2154 return failure();
2155 } else {
2156 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2157 parser.getContext(), mlir::acc::DeviceType::None));
2158 }
2159 } while (succeeded(parser.parseOptionalComma()));
2160
2161 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2162 attributes.end());
2163 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2164 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2165
2166 return success();
2167}
2168
2171 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2172 std::optional<mlir::DenseI32ArrayAttr> segments) {
2173 unsigned opIdx = 0;
2174 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2175 p << "{";
2176 llvm::interleaveComma(
2177 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2178 p << operands[opIdx] << " : " << operands[opIdx].getType();
2179 ++opIdx;
2180 });
2181 p << "}";
2182 printSingleDeviceType(p, it.value());
2183 });
2184}
2185
2186static ParseResult parseWaitClause(
2187 mlir::OpAsmParser &parser,
2189 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2190 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2191 mlir::ArrayAttr &keywordOnly) {
2192 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2194
2195 bool needCommaBeforeOperands = false;
2196
2197 // Keyword only
2198 if (failed(parser.parseOptionalLParen())) {
2199 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2200 parser.getContext(), mlir::acc::DeviceType::None));
2201 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2202 return success();
2203 }
2204
2205 // Parse keyword only attributes
2206 if (succeeded(parser.parseOptionalLSquare())) {
2207 if (failed(parser.parseCommaSeparatedList([&]() {
2208 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2209 return failure();
2210 return success();
2211 })))
2212 return failure();
2213 if (parser.parseRSquare())
2214 return failure();
2215 needCommaBeforeOperands = true;
2216 }
2217
2218 if (needCommaBeforeOperands && failed(parser.parseComma()))
2219 return failure();
2220
2221 do {
2222 if (failed(parser.parseLBrace()))
2223 return failure();
2224
2225 int32_t crtOperandsSize = operands.size();
2226
2227 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2228 if (failed(parser.parseColon()))
2229 return failure();
2230 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2231 } else {
2232 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2233 }
2234
2235 if (failed(parser.parseCommaSeparatedList(
2237 if (parser.parseOperand(operands.emplace_back()) ||
2238 parser.parseColonType(types.emplace_back()))
2239 return failure();
2240 return success();
2241 })))
2242 return failure();
2243
2244 seg.push_back(operands.size() - crtOperandsSize);
2245
2246 if (failed(parser.parseRBrace()))
2247 return failure();
2248
2249 if (succeeded(parser.parseOptionalLSquare())) {
2250 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2251 parser.parseRSquare())
2252 return failure();
2253 } else {
2254 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2255 parser.getContext(), mlir::acc::DeviceType::None));
2256 }
2257 } while (succeeded(parser.parseOptionalComma()));
2258
2259 if (failed(parser.parseRParen()))
2260 return failure();
2261
2262 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2263 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2264 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2265 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2266
2267 return success();
2268}
2269
2270static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2271 if (!hasDeviceTypeValues(attrs))
2272 return false;
2273 if (attrs->size() != 1)
2274 return false;
2275 if (auto deviceTypeAttr =
2276 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2277 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2278 return false;
2279}
2280
2282 mlir::OperandRange operands, mlir::TypeRange types,
2283 std::optional<mlir::ArrayAttr> deviceTypes,
2284 std::optional<mlir::DenseI32ArrayAttr> segments,
2285 std::optional<mlir::ArrayAttr> hasDevNum,
2286 std::optional<mlir::ArrayAttr> keywordOnly) {
2287
2288 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2289 return;
2290
2291 p << "(";
2292
2293 printDeviceTypes(p, keywordOnly);
2294 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2295 p << ", ";
2296
2297 if (hasDeviceTypeValues(deviceTypes)) {
2298 unsigned opIdx = 0;
2299 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2300 p << "{";
2301 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2302 if (boolAttr && boolAttr.getValue())
2303 p << "devnum: ";
2304 llvm::interleaveComma(
2305 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2306 p << operands[opIdx] << " : " << operands[opIdx].getType();
2307 ++opIdx;
2308 });
2309 p << "}";
2310 printSingleDeviceType(p, it.value());
2311 });
2312 }
2313
2314 p << ")";
2315}
2316
2317static ParseResult parseDeviceTypeOperands(
2318 mlir::OpAsmParser &parser,
2320 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2322 if (failed(parser.parseCommaSeparatedList([&]() {
2323 if (parser.parseOperand(operands.emplace_back()) ||
2324 parser.parseColonType(types.emplace_back()))
2325 return failure();
2326 if (succeeded(parser.parseOptionalLSquare())) {
2327 if (parser.parseAttribute(attributes.emplace_back()) ||
2328 parser.parseRSquare())
2329 return failure();
2330 } else {
2331 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2332 parser.getContext(), mlir::acc::DeviceType::None));
2333 }
2334 return success();
2335 })))
2336 return failure();
2337 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2338 attributes.end());
2339 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2340 return success();
2341}
2342
2343static void
2345 mlir::OperandRange operands, mlir::TypeRange types,
2346 std::optional<mlir::ArrayAttr> deviceTypes) {
2347 if (!hasDeviceTypeValues(deviceTypes))
2348 return;
2349 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2350 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2351 printSingleDeviceType(p, std::get<0>(it));
2352 });
2353}
2354
2356 mlir::OpAsmParser &parser,
2358 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2359 mlir::ArrayAttr &keywordOnlyDeviceType) {
2360
2361 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2362 bool needCommaBeforeOperands = false;
2363
2364 if (failed(parser.parseOptionalLParen())) {
2365 // Keyword only
2366 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2367 parser.getContext(), mlir::acc::DeviceType::None));
2368 keywordOnlyDeviceType =
2369 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2370 return success();
2371 }
2372
2373 // Parse keyword only attributes
2374 if (succeeded(parser.parseOptionalLSquare())) {
2375 // Parse keyword only attributes
2376 if (failed(parser.parseCommaSeparatedList([&]() {
2377 if (parser.parseAttribute(
2378 keywordOnlyDeviceTypeAttributes.emplace_back()))
2379 return failure();
2380 return success();
2381 })))
2382 return failure();
2383 if (parser.parseRSquare())
2384 return failure();
2385 needCommaBeforeOperands = true;
2386 }
2387
2388 if (needCommaBeforeOperands && failed(parser.parseComma()))
2389 return failure();
2390
2392 if (failed(parser.parseCommaSeparatedList([&]() {
2393 if (parser.parseOperand(operands.emplace_back()) ||
2394 parser.parseColonType(types.emplace_back()))
2395 return failure();
2396 if (succeeded(parser.parseOptionalLSquare())) {
2397 if (parser.parseAttribute(attributes.emplace_back()) ||
2398 parser.parseRSquare())
2399 return failure();
2400 } else {
2401 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2402 parser.getContext(), mlir::acc::DeviceType::None));
2403 }
2404 return success();
2405 })))
2406 return failure();
2407
2408 if (failed(parser.parseRParen()))
2409 return failure();
2410
2411 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2412 attributes.end());
2413 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2414 return success();
2415}
2416
2419 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2420 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2421
2422 if (operands.begin() == operands.end() &&
2423 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2424 return;
2425 }
2426
2427 p << "(";
2428 printDeviceTypes(p, keywordOnlyDeviceTypes);
2429 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2430 hasDeviceTypeValues(deviceTypes))
2431 p << ", ";
2432 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2433 p << ")";
2434}
2435
2437 mlir::OpAsmParser &parser,
2438 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2439 mlir::Type &operandType, mlir::UnitAttr &attr) {
2440 // Keyword only
2441 if (failed(parser.parseOptionalLParen())) {
2442 attr = mlir::UnitAttr::get(parser.getContext());
2443 return success();
2444 }
2445
2447 if (failed(parser.parseOperand(op)))
2448 return failure();
2449 operand = op;
2450 if (failed(parser.parseColon()))
2451 return failure();
2452 if (failed(parser.parseType(operandType)))
2453 return failure();
2454 if (failed(parser.parseRParen()))
2455 return failure();
2456
2457 return success();
2458}
2459
2461 mlir::Operation *op,
2462 std::optional<mlir::Value> operand,
2463 mlir::Type operandType,
2464 mlir::UnitAttr attr) {
2465 if (attr)
2466 return;
2467
2468 p << "(";
2469 p.printOperand(*operand);
2470 p << " : ";
2471 p.printType(operandType);
2472 p << ")";
2473}
2474
2476 mlir::OpAsmParser &parser,
2478 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2479 // Keyword only
2480 if (failed(parser.parseOptionalLParen())) {
2481 attr = mlir::UnitAttr::get(parser.getContext());
2482 return success();
2483 }
2484
2485 if (failed(parser.parseCommaSeparatedList([&]() {
2486 if (parser.parseOperand(operands.emplace_back()))
2487 return failure();
2488 return success();
2489 })))
2490 return failure();
2491 if (failed(parser.parseColon()))
2492 return failure();
2493 if (failed(parser.parseCommaSeparatedList([&]() {
2494 if (parser.parseType(types.emplace_back()))
2495 return failure();
2496 return success();
2497 })))
2498 return failure();
2499 if (failed(parser.parseRParen()))
2500 return failure();
2501
2502 return success();
2503}
2504
2506 mlir::Operation *op,
2507 mlir::OperandRange operands,
2508 mlir::TypeRange types,
2509 mlir::UnitAttr attr) {
2510 if (attr)
2511 return;
2512
2513 p << "(";
2514 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2515 p << " : ";
2516 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2517 p << ")";
2518}
2519
2520static ParseResult
2522 mlir::acc::CombinedConstructsTypeAttr &attr) {
2523 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2524 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2525 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2526 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2527 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2528 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2529 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2530 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2531 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2532 } else {
2533 parser.emitError(parser.getCurrentLocation(),
2534 "expected compute construct name");
2535 return failure();
2536 }
2537 return success();
2538}
2539
2540static void
2542 mlir::acc::CombinedConstructsTypeAttr attr) {
2543 if (attr) {
2544 switch (attr.getValue()) {
2545 case mlir::acc::CombinedConstructsType::KernelsLoop:
2546 p << "kernels";
2547 break;
2548 case mlir::acc::CombinedConstructsType::ParallelLoop:
2549 p << "parallel";
2550 break;
2551 case mlir::acc::CombinedConstructsType::SerialLoop:
2552 p << "serial";
2553 break;
2554 };
2555 }
2556}
2557
2558//===----------------------------------------------------------------------===//
2559// SerialOp
2560//===----------------------------------------------------------------------===//
2561
2562unsigned SerialOp::getNumDataOperands() {
2563 return getReductionOperands().size() + getPrivateOperands().size() +
2564 getFirstprivateOperands().size() + getDataClauseOperands().size();
2565}
2566
2567Value SerialOp::getDataOperand(unsigned i) {
2568 unsigned numOptional = getAsyncOperands().size();
2569 numOptional += getIfCond() ? 1 : 0;
2570 numOptional += getSelfCond() ? 1 : 0;
2571 return getOperand(getWaitOperands().size() + numOptional + i);
2572}
2573
2574bool acc::SerialOp::hasAsyncOnly() {
2575 return hasAsyncOnly(mlir::acc::DeviceType::None);
2576}
2577
2578bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2579 return hasDeviceType(getAsyncOnly(), deviceType);
2580}
2581
2582mlir::Value acc::SerialOp::getAsyncValue() {
2583 return getAsyncValue(mlir::acc::DeviceType::None);
2584}
2585
2586mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2588 getAsyncOperands(), deviceType);
2589}
2590
2591bool acc::SerialOp::hasWaitOnly() {
2592 return hasWaitOnly(mlir::acc::DeviceType::None);
2593}
2594
2595bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2596 return hasDeviceType(getWaitOnly(), deviceType);
2597}
2598
2599mlir::Operation::operand_range SerialOp::getWaitValues() {
2600 return getWaitValues(mlir::acc::DeviceType::None);
2601}
2602
2604SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2606 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2607 getHasWaitDevnum(), deviceType);
2608}
2609
2610mlir::Value SerialOp::getWaitDevnum() {
2611 return getWaitDevnum(mlir::acc::DeviceType::None);
2612}
2613
2614mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2615 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2616 getWaitOperandsSegments(), getHasWaitDevnum(),
2617 deviceType);
2618}
2619
2620LogicalResult acc::SerialOp::verify() {
2621 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2622 mlir::acc::PrivateRecipeOp>(
2623 *this, getPrivateOperands(), "private")))
2624 return failure();
2625 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2626 mlir::acc::FirstprivateRecipeOp>(
2627 *this, getFirstprivateOperands(), "firstprivate")))
2628 return failure();
2629 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2630 mlir::acc::ReductionRecipeOp>(
2631 *this, getReductionOperands(), "reduction")))
2632 return failure();
2633
2635 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2636 getWaitOperandsDeviceTypeAttr(), "wait")))
2637 return failure();
2638
2640 getAsyncOperandsDeviceTypeAttr(),
2641 "async")))
2642 return failure();
2643
2645 return failure();
2646
2647 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2648}
2649
2650void acc::SerialOp::addAsyncOnly(
2651 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2652 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2653 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2654}
2655
2656void acc::SerialOp::addAsyncOperand(
2657 MLIRContext *context, mlir::Value newValue,
2658 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2659 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2660 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2661 getAsyncOperandsMutable()));
2662}
2663
2664void acc::SerialOp::addWaitOnly(
2665 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2666 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2667 effectiveDeviceTypes));
2668}
2669void acc::SerialOp::addWaitOperands(
2670 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2671 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2672
2674 if (getWaitOperandsSegments())
2675 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2676
2677 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2678 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2679 getWaitOperandsMutable(), segments));
2680 setWaitOperandsSegments(segments);
2681
2683 if (getHasWaitDevnumAttr())
2684 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2685 hasDevnums.insert(
2686 hasDevnums.end(),
2687 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2688 mlir::BoolAttr::get(context, hasDevnum));
2689 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2690}
2691
2692void acc::SerialOp::addPrivatization(MLIRContext *context,
2693 mlir::acc::PrivateOp op,
2694 mlir::acc::PrivateRecipeOp recipe) {
2695 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2696 getPrivateOperandsMutable().append(op.getResult());
2697}
2698
2699void acc::SerialOp::addFirstPrivatization(
2700 MLIRContext *context, mlir::acc::FirstprivateOp op,
2701 mlir::acc::FirstprivateRecipeOp recipe) {
2702 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2703 getFirstprivateOperandsMutable().append(op.getResult());
2704}
2705
2706void acc::SerialOp::addReduction(MLIRContext *context,
2707 mlir::acc::ReductionOp op,
2708 mlir::acc::ReductionRecipeOp recipe) {
2709 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2710 getReductionOperandsMutable().append(op.getResult());
2711}
2712
2713//===----------------------------------------------------------------------===//
2714// KernelsOp
2715//===----------------------------------------------------------------------===//
2716
2717unsigned KernelsOp::getNumDataOperands() {
2718 return getDataClauseOperands().size();
2719}
2720
2721Value KernelsOp::getDataOperand(unsigned i) {
2722 unsigned numOptional = getAsyncOperands().size();
2723 numOptional += getWaitOperands().size();
2724 numOptional += getNumGangs().size();
2725 numOptional += getNumWorkers().size();
2726 numOptional += getVectorLength().size();
2727 numOptional += getIfCond() ? 1 : 0;
2728 numOptional += getSelfCond() ? 1 : 0;
2729 return getOperand(numOptional + i);
2730}
2731
2732bool acc::KernelsOp::hasAsyncOnly() {
2733 return hasAsyncOnly(mlir::acc::DeviceType::None);
2734}
2735
2736bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2737 return hasDeviceType(getAsyncOnly(), deviceType);
2738}
2739
2740mlir::Value acc::KernelsOp::getAsyncValue() {
2741 return getAsyncValue(mlir::acc::DeviceType::None);
2742}
2743
2744mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2746 getAsyncOperands(), deviceType);
2747}
2748
2749mlir::Value acc::KernelsOp::getNumWorkersValue() {
2750 return getNumWorkersValue(mlir::acc::DeviceType::None);
2751}
2752
2754acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2755 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2756 deviceType);
2757}
2758
2759mlir::Value acc::KernelsOp::getVectorLengthValue() {
2760 return getVectorLengthValue(mlir::acc::DeviceType::None);
2761}
2762
2764acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2765 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2766 getVectorLength(), deviceType);
2767}
2768
2769mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2770 return getNumGangsValues(mlir::acc::DeviceType::None);
2771}
2772
2774KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2775 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2776 getNumGangsSegments(), deviceType);
2777}
2778
2779bool acc::KernelsOp::hasWaitOnly() {
2780 return hasWaitOnly(mlir::acc::DeviceType::None);
2781}
2782
2783bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2784 return hasDeviceType(getWaitOnly(), deviceType);
2785}
2786
2787mlir::Operation::operand_range KernelsOp::getWaitValues() {
2788 return getWaitValues(mlir::acc::DeviceType::None);
2789}
2790
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Returns the wait values for `deviceType` by handing the flat wait operand
  // list, its device_type and segment attributes, and the per-segment devnum
  // flags to a shared helper (presumably the devnum is excluded from the
  // result — confirm against the helper).
  // NOTE(review): the return-type line and the opening of the return
  // statement (the helper call) are absent from this listing; confirm against
  // the full source before editing.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
2797
2798mlir::Value KernelsOp::getWaitDevnum() {
2799 return getWaitDevnum(mlir::acc::DeviceType::None);
2800}
2801
2802mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2803 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2804 getWaitOperandsSegments(), getHasWaitDevnum(),
2805 deviceType);
2806}
2807
/// Verifies acc.kernels. Checks per-device_type consistency of the num_gangs,
/// wait, num_workers, vector_length and async clauses, then the data clause
/// operands via checkDataOperands.
LogicalResult acc::KernelsOp::verify() {
  // num_gangs: operands, segment sizes and device_types must agree. The
  // trailing 3 presumably caps the values per device_type (num_gangs takes up
  // to three dimensions) — TODO confirm against the helper.
  // NOTE(review): the `if (failed(<helper>(` line opening this check is absent
  // from this listing.
          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

  // wait: operands, segment sizes and device_types must agree.
  // NOTE(review): the opening line of this check is absent from this listing.
          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  // num_workers and vector_length: the number of values must match the
  // number of device_types.
  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

  // async: same count check.
  // NOTE(review): the line opening this call (presumably
  // verifyDeviceTypeCountMatch on getAsyncOperands()) is absent here.
                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

  // NOTE(review): the condition guarding this early return is absent from
  // this listing; confirm against the full source.
    return failure();

  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
2839
2840void acc::KernelsOp::addPrivatization(MLIRContext *context,
2841 mlir::acc::PrivateOp op,
2842 mlir::acc::PrivateRecipeOp recipe) {
2843 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2844 getPrivateOperandsMutable().append(op.getResult());
2845}
2846
2847void acc::KernelsOp::addFirstPrivatization(
2848 MLIRContext *context, mlir::acc::FirstprivateOp op,
2849 mlir::acc::FirstprivateRecipeOp recipe) {
2850 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2851 getFirstprivateOperandsMutable().append(op.getResult());
2852}
2853
2854void acc::KernelsOp::addReduction(MLIRContext *context,
2855 mlir::acc::ReductionOp op,
2856 mlir::acc::ReductionRecipeOp recipe) {
2857 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2858 getReductionOperandsMutable().append(op.getResult());
2859}
2860
2861void acc::KernelsOp::addNumWorkersOperand(
2862 MLIRContext *context, mlir::Value newValue,
2863 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2864 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2865 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2866 getNumWorkersMutable()));
2867}
2868
2869void acc::KernelsOp::addVectorLengthOperand(
2870 MLIRContext *context, mlir::Value newValue,
2871 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2872 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2873 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2874 getVectorLengthMutable()));
2875}
2876void acc::KernelsOp::addAsyncOnly(
2877 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2878 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2879 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2880}
2881
2882void acc::KernelsOp::addAsyncOperand(
2883 MLIRContext *context, mlir::Value newValue,
2884 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2885 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2886 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2887 getAsyncOperandsMutable()));
2888}
2889
2890void acc::KernelsOp::addNumGangsOperands(
2891 MLIRContext *context, mlir::ValueRange newValues,
2892 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2894 if (getNumGangsSegmentsAttr())
2895 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2896
2897 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2898 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2899 getNumGangsMutable(), segments));
2900
2901 setNumGangsSegments(segments);
2902}
2903
2904void acc::KernelsOp::addWaitOnly(
2905 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2906 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2907 effectiveDeviceTypes));
2908}
2909void acc::KernelsOp::addWaitOperands(
2910 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2911 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2912
2914 if (getWaitOperandsSegments())
2915 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2916
2917 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2918 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2919 getWaitOperandsMutable(), segments));
2920 setWaitOperandsSegments(segments);
2921
2923 if (getHasWaitDevnumAttr())
2924 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2925 hasDevnums.insert(
2926 hasDevnums.end(),
2927 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2928 mlir::BoolAttr::get(context, hasDevnum));
2929 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2930}
2931
2932//===----------------------------------------------------------------------===//
2933// HostDataOp
2934//===----------------------------------------------------------------------===//
2935
2936LogicalResult acc::HostDataOp::verify() {
2937 if (getDataClauseOperands().empty())
2938 return emitError("at least one operand must appear on the host_data "
2939 "operation");
2940
2942 for (mlir::Value operand : getDataClauseOperands()) {
2943 auto useDeviceOp =
2944 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
2945 if (!useDeviceOp)
2946 return emitError("expect data entry operation as defining op");
2947
2948 // Check for duplicate use_device clauses
2949 if (!seenVars.insert(useDeviceOp.getVar()).second)
2950 return emitError("duplicate use_device variable");
2951 }
2952 return success();
2953}
2954
2955void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2956 MLIRContext *context) {
2957 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2958}
2959
2960//===----------------------------------------------------------------------===//
2961// KernelEnvironmentOp
2962//===----------------------------------------------------------------------===//
2963
2964void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2965 RewritePatternSet &results, MLIRContext *context) {
2966 results.add<RemoveEmptyKernelEnvironment>(context);
2967}
2968
2969//===----------------------------------------------------------------------===//
2970// LoopOp
2971//===----------------------------------------------------------------------===//
2972
2973static ParseResult parseGangValue(
2974 OpAsmParser &parser, llvm::StringRef keyword,
2977 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2978 bool &needCommaBetweenValues, bool &newValue) {
2979 if (succeeded(parser.parseOptionalKeyword(keyword))) {
2980 if (parser.parseEqual())
2981 return failure();
2982 if (parser.parseOperand(operands.emplace_back()) ||
2983 parser.parseColonType(types.emplace_back()))
2984 return failure();
2985 attributes.push_back(gangArgType);
2986 needCommaBetweenValues = true;
2987 newValue = true;
2988 }
2989 return success();
2990}
2991
2992static ParseResult parseGangClause(
2993 OpAsmParser &parser,
2995 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2996 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2997 mlir::ArrayAttr &gangOnlyDeviceType) {
2998 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2999 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3000 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3002 bool needCommaBetweenValues = false;
3003 bool needCommaBeforeOperands = false;
3004
3005 if (failed(parser.parseOptionalLParen())) {
3006 // Gang only keyword
3007 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3008 parser.getContext(), mlir::acc::DeviceType::None));
3009 gangOnlyDeviceType =
3010 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3011 return success();
3012 }
3013
3014 // Parse gang only attributes
3015 if (succeeded(parser.parseOptionalLSquare())) {
3016 // Parse gang only attributes
3017 if (failed(parser.parseCommaSeparatedList([&]() {
3018 if (parser.parseAttribute(
3019 gangOnlyDeviceTypeAttributes.emplace_back()))
3020 return failure();
3021 return success();
3022 })))
3023 return failure();
3024 if (parser.parseRSquare())
3025 return failure();
3026 needCommaBeforeOperands = true;
3027 }
3028
3029 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3030 mlir::acc::GangArgType::Num);
3031 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3032 mlir::acc::GangArgType::Dim);
3033 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3034 parser.getContext(), mlir::acc::GangArgType::Static);
3035
3036 do {
3037 if (needCommaBeforeOperands) {
3038 needCommaBeforeOperands = false;
3039 continue;
3040 }
3041
3042 if (failed(parser.parseLBrace()))
3043 return failure();
3044
3045 int32_t crtOperandsSize = gangOperands.size();
3046 while (true) {
3047 bool newValue = false;
3048 bool needValue = false;
3049 if (needCommaBetweenValues) {
3050 if (succeeded(parser.parseOptionalComma()))
3051 needValue = true; // expect a new value after comma.
3052 else
3053 break;
3054 }
3055
3056 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3057 gangOperands, gangOperandsType,
3058 gangArgTypeAttributes, argNum,
3059 needCommaBetweenValues, newValue)))
3060 return failure();
3061 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3062 gangOperands, gangOperandsType,
3063 gangArgTypeAttributes, argDim,
3064 needCommaBetweenValues, newValue)))
3065 return failure();
3066 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3067 gangOperands, gangOperandsType,
3068 gangArgTypeAttributes, argStatic,
3069 needCommaBetweenValues, newValue)))
3070 return failure();
3071
3072 if (!newValue && needValue) {
3073 parser.emitError(parser.getCurrentLocation(),
3074 "new value expected after comma");
3075 return failure();
3076 }
3077
3078 if (!newValue)
3079 break;
3080 }
3081
3082 if (gangOperands.empty())
3083 return parser.emitError(
3084 parser.getCurrentLocation(),
3085 "expect at least one of num, dim or static values");
3086
3087 if (failed(parser.parseRBrace()))
3088 return failure();
3089
3090 if (succeeded(parser.parseOptionalLSquare())) {
3091 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3092 parser.parseRSquare())
3093 return failure();
3094 } else {
3095 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3096 parser.getContext(), mlir::acc::DeviceType::None));
3097 }
3098
3099 seg.push_back(gangOperands.size() - crtOperandsSize);
3100
3101 } while (succeeded(parser.parseOptionalComma()));
3102
3103 if (failed(parser.parseRParen()))
3104 return failure();
3105
3106 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3107 gangArgTypeAttributes.end());
3108 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3109 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3110
3112 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3113 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3114
3115 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3116 return success();
3117}
3118
3120 mlir::OperandRange operands, mlir::TypeRange types,
3121 std::optional<mlir::ArrayAttr> gangArgTypes,
3122 std::optional<mlir::ArrayAttr> deviceTypes,
3123 std::optional<mlir::DenseI32ArrayAttr> segments,
3124 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3125
3126 if (operands.begin() == operands.end() &&
3127 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3128 return;
3129 }
3130
3131 p << "(";
3132
3133 printDeviceTypes(p, gangOnlyDeviceTypes);
3134
3135 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3136 hasDeviceTypeValues(deviceTypes))
3137 p << ", ";
3138
3139 if (hasDeviceTypeValues(deviceTypes)) {
3140 unsigned opIdx = 0;
3141 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3142 p << "{";
3143 llvm::interleaveComma(
3144 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3145 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3146 (*gangArgTypes)[opIdx]);
3147 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3148 p << LoopOp::getGangNumKeyword();
3149 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3150 p << LoopOp::getGangDimKeyword();
3151 else if (gangArgTypeAttr.getValue() ==
3152 mlir::acc::GangArgType::Static)
3153 p << LoopOp::getGangStaticKeyword();
3154 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3155 ++opIdx;
3156 });
3157 p << "}";
3158 printSingleDeviceType(p, it.value());
3159 });
3160 }
3161 p << ")";
3162}
3163
3165 std::optional<mlir::ArrayAttr> segments,
3166 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3167 if (!segments)
3168 return false;
3169 for (auto attr : *segments) {
3170 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3171 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3172 return true;
3173 }
3174 return false;
3175}
3176
3177/// Check for duplicates in the DeviceType array attribute.
3178/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3179static std::optional<mlir::acc::DeviceType>
3180checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3181 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3182 if (!deviceTypes)
3183 return std::nullopt;
3184 for (auto attr : deviceTypes) {
3185 auto deviceTypeAttr =
3186 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3187 if (!deviceTypeAttr)
3188 return mlir::acc::DeviceType::None;
3189 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3190 return deviceTypeAttr.getValue();
3191 }
3192 return std::nullopt;
3193}
3194
/// Verifies acc.loop. In order, it checks:
///   - bound/step arity;
///   - the per-device_type clause attributes (collapse, gang, worker, vector,
///     tile);
///   - mutual exclusion of auto/independent/seq per device_type, and that one
///     of them applies to device_type none;
///   - that seq is not combined with gang/worker/vector;
///   - the private/firstprivate/reduction operands;
///   - the combined-construct attribute;
///   - a non-empty body;
///   - that container-like loops enclose enough nested loop-like operations.
LogicalResult acc::LoopOp::verify() {
  // Lower bounds, upper bounds and steps are parallel lists (one entry per
  // induction variable).
  if (getUpperbound().size() != getStep().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of steps";

  if (getUpperbound().size() != getLowerbound().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of lowerbounds";

  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse: values and device_types must both be present, have the
  // same length, and name each device_type at most once.
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in collapseDeviceType attribute";

  // Check gang: each gang operand needs a matching arg-type (num/dim/static)
  // entry.
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr()) {
    if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
      return emitOpError() << "duplicate device_type `"
                           << acc::stringifyDeviceType(*duplicateDeviceType)
                           << "` found in gang attribute";
  }

  // NOTE(review): the `if (failed(<helper>(` line opening this gang segment
  // check is absent from this listing; confirm against the full source.
          *this, getGangOperands(), getGangOperandsSegmentsAttr(),
          getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in worker attribute";
  if (auto duplicateDeviceType =
          checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in vector attribute";
  if (auto duplicateDeviceType =
          checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

  // NOTE(review): the line opening this tile segment check is absent from
  // this listing.
          *this, getTileOperands(), getTileOperandsSegmentsAttr(),
          getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive. The same set
  // is shared across the three calls, so any device_type named by two of the
  // attributes is reported.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of auto, independent, seq can be present "
                          "at the same time";
  }

  // Check that at least one of auto, independent, or seq is present
  // for the device-independent default clauses.
  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
    return attr.getValue() == mlir::acc::DeviceType::None;
  };
  bool hasDefaultSeq =
      getSeqAttr()
          ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  bool hasDefaultIndependent =
      getIndependentAttr()
          ? llvm::any_of(
                getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                hasDeviceNone)
          : false;
  bool hasDefaultAuto =
      getAuto_Attr()
          ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
    return emitError()
           << "at least one of auto, independent, seq must be present";
  }

  // Gang, worker and vector are incompatible with seq, checked separately
  // for each device_type that has seq.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError() << "gang, worker or vector cannot appear with seq";
    }
  }

  if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
                                  mlir::acc::PrivateRecipeOp>(
          *this, getPrivateOperands(), "private")))
    return failure();

  if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
                                  mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivateOperands(), "firstprivate")))
    return failure();

  if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
                                  mlir::acc::ReductionRecipeOp>(
          *this, getReductionOperands(), "reduction")))
    return failure();

  // Only the *_loop combined forms may be recorded on an acc.loop.
  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  if (getUnstructured()) {
    if (!isContainerLike())
      return emitError(
          "unstructured acc.loop must not have induction variables");
  } else if (isContainerLike()) {
    // When it is container-like - it is expected to hold a loop-like operation.
    // Obtain the maximum collapse count - we use this to check that there
    // are enough loops contained.
    uint64_t collapseCount = getCollapseValue().value_or(1);
    if (getCollapseAttr()) {
      for (auto collapseEntry : getCollapseAttr()) {
        auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
        if (intAttr.getValue().getZExtValue() > collapseCount)
          collapseCount = intAttr.getValue().getZExtValue();
      }
    }

    // We want to check that we find enough loop-like operations inside.
    // A pre-order walk visits each operation before the operations nested in
    // it, so an enclosing loop is always seen before the loops it contains.
    // Each loop found must be directly nested (by loop-like parent) in the
    // previously found one; anything else is reported as a sibling.
    // NOTE(review): the walk callback's `return WalkResult::...` statements
    // are absent from this listing; confirm against the full source.
    mlir::Operation *expectedParent = this->getOperation();
    bool foundSibling = false;
    getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
      if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
        // This effectively checks that we are not looking at a sibling loop.
        if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
            expectedParent) {
          foundSibling = true;
        }

        collapseCount--;
        expectedParent = op;
      }
      // We found enough contained loops.
      if (collapseCount == 0)
    });

    if (foundSibling)
      return emitError("found sibling loops inside container-like acc.loop");
    if (collapseCount != 0)
      return emitError("failed to find enough loop-like operations inside "
                       "container-like acc.loop");
  }

  return success();
}
3411
3412unsigned LoopOp::getNumDataOperands() {
3413 return getReductionOperands().size() + getPrivateOperands().size() +
3414 getFirstprivateOperands().size();
3415}
3416
3417Value LoopOp::getDataOperand(unsigned i) {
3418 unsigned numOptional =
3419 getLowerbound().size() + getUpperbound().size() + getStep().size();
3420 numOptional += getGangOperands().size();
3421 numOptional += getVectorOperands().size();
3422 numOptional += getWorkerNumOperands().size();
3423 numOptional += getTileOperands().size();
3424 numOptional += getCacheOperands().size();
3425 return getOperand(numOptional + i);
3426}
3427
3428bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3429
3430bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3431 return hasDeviceType(getAuto_(), deviceType);
3432}
3433
3434bool LoopOp::hasIndependent() {
3435 return hasIndependent(mlir::acc::DeviceType::None);
3436}
3437
3438bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3439 return hasDeviceType(getIndependent(), deviceType);
3440}
3441
3442bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3443
3444bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3445 return hasDeviceType(getSeq(), deviceType);
3446}
3447
3448mlir::Value LoopOp::getVectorValue() {
3449 return getVectorValue(mlir::acc::DeviceType::None);
3450}
3451
3452mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3453 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3454 getVectorOperands(), deviceType);
3455}
3456
3457bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3458
3459bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3460 return hasDeviceType(getVector(), deviceType);
3461}
3462
3463mlir::Value LoopOp::getWorkerValue() {
3464 return getWorkerValue(mlir::acc::DeviceType::None);
3465}
3466
3467mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3468 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3469 getWorkerNumOperands(), deviceType);
3470}
3471
3472bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3473
3474bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3475 return hasDeviceType(getWorker(), deviceType);
3476}
3477
3478mlir::Operation::operand_range LoopOp::getTileValues() {
3479 return getTileValues(mlir::acc::DeviceType::None);
3480}
3481
3483LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3484 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3485 getTileOperandsSegments(), deviceType);
3486}
3487
3488std::optional<int64_t> LoopOp::getCollapseValue() {
3489 return getCollapseValue(mlir::acc::DeviceType::None);
3490}
3491
3492std::optional<int64_t>
3493LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3494 if (!getCollapseAttr())
3495 return std::nullopt;
3496 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3497 auto intAttr =
3498 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3499 return intAttr.getValue().getZExtValue();
3500 }
3501 return std::nullopt;
3502}
3503
3504mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3505 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3506}
3507
3508mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3509 mlir::acc::DeviceType deviceType) {
3510 if (getGangOperands().empty())
3511 return {};
3512 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3513 int32_t nbOperandsBefore = 0;
3514 for (unsigned i = 0; i < *pos; ++i)
3515 nbOperandsBefore += (*getGangOperandsSegments())[i];
3517 getGangOperands()
3518 .drop_front(nbOperandsBefore)
3519 .take_front((*getGangOperandsSegments())[*pos]);
3520
3521 int32_t argTypeIdx = nbOperandsBefore;
3522 for (auto value : values) {
3523 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3524 (*getGangOperandsArgType())[argTypeIdx]);
3525 if (gangArgTypeAttr.getValue() == gangArgType)
3526 return value;
3527 ++argTypeIdx;
3528 }
3529 }
3530 return {};
3531}
3532
3533bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3534
3535bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3536 return hasDeviceType(getGang(), deviceType);
3537}
3538
3539llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3540 return {&getRegion()};
3541}
3542
/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
///                  `(` ssa-id-and-type-list `)` `to`
///                  `(` ssa-id-and-type-list `)` `step`
///                  `(` ssa-id-and-type-list `)`
///                  region
///
/// The `control(...) = (...) to (...) step (...)` prefix is optional; without
/// it only the region is parsed. Each bound/step list is parsed with as many
/// operands as there are induction variables, followed by a colon type list.
/// The parsed induction variables become the region's entry arguments.
ParseResult
    SmallVectorImpl<Type> &lowerboundType,
    SmallVectorImpl<Type> &upperboundType,
    SmallVectorImpl<Type> &stepType) {

  // NOTE(review): parts of the signature, the `inductionVars` declaration and
  // the trailing delimiter argument of each parseOperandList call are absent
  // from this listing; confirm against the full source before editing.
  if (succeeded(
          parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
    if (parser.parseLParen() ||
        parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
                                 /*allowType=*/true) ||
        parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
        parser.parseOperandList(lowerbound, inductionVars.size(),
        parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
        parser.parseKeyword("to") || parser.parseLParen() ||
        parser.parseOperandList(upperbound, inductionVars.size(),
        parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
        parser.parseKeyword("step") || parser.parseLParen() ||
        parser.parseOperandList(step, inductionVars.size(),
        parser.parseColonTypeList(stepType) || parser.parseRParen())
      return failure();
  }
  return parser.parseRegion(region, inductionVars);
}
3578
    ValueRange lowerbound, TypeRange lowerboundType,
    ValueRange upperbound, TypeRange upperboundType,
    ValueRange steps, TypeRange stepType) {
  // Prints `control(%iv : type, ...) = (lbs : types) to (ubs : types)
  // step (steps : types)` only when the region's entry block has arguments,
  // then the region without its entry block arguments (they are spelled in
  // the prefix instead).
  // NOTE(review): the first line of the signature is absent from this
  // listing.
  ValueRange regionArgs = region.front().getArguments();
  if (!regionArgs.empty()) {
    p << acc::LoopOp::getControlKeyword() << "(";
    llvm::interleaveComma(regionArgs, p,
                          [&p](Value v) { p << v << " : " << v.getType(); });
    // NOTE(review): `") "` followed by `" step ("` emits two spaces before
    // `step`. The parser does not care, but it differs from the single-space
    // separators used elsewhere.
    p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
      << upperbound << " : " << upperboundType << ") " << " step (" << steps
      << " : " << stepType << ") ";
  }
  p.printRegion(region, /*printEntryBlockArgs=*/false);
}
3594
3595void acc::LoopOp::addSeq(MLIRContext *context,
3596 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3597 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3598 effectiveDeviceTypes));
3599}
3600
3601void acc::LoopOp::addIndependent(
3602 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3603 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3604 context, getIndependentAttr(), effectiveDeviceTypes));
3605}
3606
3607void acc::LoopOp::addAuto(MLIRContext *context,
3608 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3609 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3610 effectiveDeviceTypes));
3611}
3612
3613void acc::LoopOp::setCollapseForDeviceTypes(
3614 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3615 llvm::APInt value) {
3618
3619 assert((getCollapseAttr() == nullptr) ==
3620 (getCollapseDeviceTypeAttr() == nullptr));
3621 assert(value.getBitWidth() == 64);
3622
3623 if (getCollapseAttr()) {
3624 for (const auto &existing :
3625 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3626 newValues.push_back(std::get<0>(existing));
3627 newDeviceTypes.push_back(std::get<1>(existing));
3628 }
3629 }
3630
3631 if (effectiveDeviceTypes.empty()) {
3632 // If the effective device-types list is empty, this is before there are any
3633 // being applied by device_type, so this should be added as a 'none'.
3634 newValues.push_back(
3635 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3636 newDeviceTypes.push_back(
3637 acc::DeviceTypeAttr::get(context, DeviceType::None));
3638 } else {
3639 for (DeviceType dt : effectiveDeviceTypes) {
3640 newValues.push_back(
3641 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3642 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3643 }
3644 }
3645
3646 setCollapseAttr(ArrayAttr::get(context, newValues));
3647 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3648}
3649
3650void acc::LoopOp::setTileForDeviceTypes(
3651 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3652 ValueRange values) {
3654 if (getTileOperandsSegments())
3655 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3656
3657 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3658 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3659 getTileOperandsMutable(), segments));
3660
3661 setTileOperandsSegments(segments);
3662}
3663
3664void acc::LoopOp::addVectorOperand(
3665 MLIRContext *context, mlir::Value newValue,
3666 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3667 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3668 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3669 newValue, getVectorOperandsMutable()));
3670}
3671
3672void acc::LoopOp::addEmptyVector(
3673 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3674 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3675 effectiveDeviceTypes));
3676}
3677
3678void acc::LoopOp::addWorkerNumOperand(
3679 MLIRContext *context, mlir::Value newValue,
3680 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3681 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3682 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3683 newValue, getWorkerNumOperandsMutable()));
3684}
3685
3686void acc::LoopOp::addEmptyWorker(
3687 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3688 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3689 effectiveDeviceTypes));
3690}
3691
3692void acc::LoopOp::addEmptyGang(
3693 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3694 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3695 effectiveDeviceTypes));
3696}
3697
3698bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3699 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3700 return attr.getValue() == dt;
3701 };
3702 auto testFromArr = [=](ArrayAttr arr) -> bool {
3703 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3704 };
3705
3706 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3707 return true;
3708 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3709 return true;
3710 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3711 return true;
3712
3713 return false;
3714}
3715
3716bool acc::LoopOp::hasDefaultGangWorkerVector() {
3717 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3718 hasGang() || getGangValue(GangArgType::Num) ||
3719 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3720}
3721
3722acc::LoopParMode
3723acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3724 if (hasSeq(deviceType))
3725 return LoopParMode::loop_seq;
3726 if (hasAuto(deviceType))
3727 return LoopParMode::loop_auto;
3728 if (hasIndependent(deviceType))
3729 return LoopParMode::loop_independent;
3730 if (hasSeq())
3731 return LoopParMode::loop_seq;
3732 if (hasAuto())
3733 return LoopParMode::loop_auto;
3734 assert(hasIndependent() &&
3735 "loop must have default auto, seq, or independent");
3736 return LoopParMode::loop_independent;
3737}
3738
3739void acc::LoopOp::addGangOperands(
3740 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3743 if (std::optional<ArrayRef<int32_t>> existingSegments =
3744 getGangOperandsSegments())
3745 llvm::copy(*existingSegments, std::back_inserter(segments));
3746
3747 unsigned beforeCount = segments.size();
3748
3749 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3750 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3751 getGangOperandsMutable(), segments));
3752
3753 setGangOperandsSegments(segments);
3754
3755 // This is a bit of extra work to make sure we update the 'types' correctly by
3756 // adding to the types collection the correct number of times. We could
3757 // potentially add something similar to the
3758 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3759 // excessive for a one-off case.
3760 unsigned numAdded = segments.size() - beforeCount;
3761
3762 if (numAdded > 0) {
3764 if (getGangOperandsArgTypeAttr())
3765 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3766
3767 for (auto i : llvm::index_range(0u, numAdded)) {
3768 llvm::transform(argTypes, std::back_inserter(gangTypes),
3769 [=](mlir::acc::GangArgType gangTy) {
3770 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3771 });
3772 (void)i;
3773 }
3774
3775 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3776 }
3777}
3778
3779void acc::LoopOp::addPrivatization(MLIRContext *context,
3780 mlir::acc::PrivateOp op,
3781 mlir::acc::PrivateRecipeOp recipe) {
3782 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3783 getPrivateOperandsMutable().append(op.getResult());
3784}
3785
3786void acc::LoopOp::addFirstPrivatization(
3787 MLIRContext *context, mlir::acc::FirstprivateOp op,
3788 mlir::acc::FirstprivateRecipeOp recipe) {
3789 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3790 getFirstprivateOperandsMutable().append(op.getResult());
3791}
3792
3793void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3794 mlir::acc::ReductionRecipeOp recipe) {
3795 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3796 getReductionOperandsMutable().append(op.getResult());
3797}
3798
3799//===----------------------------------------------------------------------===//
3800// DataOp
3801//===----------------------------------------------------------------------===//
3802
3803LogicalResult acc::DataOp::verify() {
3804 // 2.6.5. Data Construct restriction
3805 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3806 // attach, or default clause must appear on a data construct.
3807 if (getOperands().empty() && !getDefaultAttr())
3808 return emitError("at least one operand or the default attribute "
3809 "must appear on the data operation");
3810
3811 for (mlir::Value operand : getDataClauseOperands())
3812 if (isa<BlockArgument>(operand) ||
3813 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3814 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3815 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3816 operand.getDefiningOp()))
3817 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3818 "as defining op");
3819
3821 return failure();
3822
3823 return success();
3824}
3825
3826unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3827
3828Value DataOp::getDataOperand(unsigned i) {
3829 unsigned numOptional = getIfCond() ? 1 : 0;
3830 numOptional += getAsyncOperands().size() ? 1 : 0;
3831 numOptional += getWaitOperands().size();
3832 return getOperand(numOptional + i);
3833}
3834
3835bool acc::DataOp::hasAsyncOnly() {
3836 return hasAsyncOnly(mlir::acc::DeviceType::None);
3837}
3838
3839bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3840 return hasDeviceType(getAsyncOnly(), deviceType);
3841}
3842
3843mlir::Value DataOp::getAsyncValue() {
3844 return getAsyncValue(mlir::acc::DeviceType::None);
3845}
3846
3847mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3849 getAsyncOperands(), deviceType);
3850}
3851
3852bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3853
3854bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3855 return hasDeviceType(getWaitOnly(), deviceType);
3856}
3857
3858mlir::Operation::operand_range DataOp::getWaitValues() {
3859 return getWaitValues(mlir::acc::DeviceType::None);
3860}
3861
3863DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3865 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3866 getHasWaitDevnum(), deviceType);
3867}
3868
3869mlir::Value DataOp::getWaitDevnum() {
3870 return getWaitDevnum(mlir::acc::DeviceType::None);
3871}
3872
3873mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3874 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3875 getWaitOperandsSegments(), getHasWaitDevnum(),
3876 deviceType);
3877}
3878
3879void acc::DataOp::addAsyncOnly(
3880 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3881 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3882 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3883}
3884
3885void acc::DataOp::addAsyncOperand(
3886 MLIRContext *context, mlir::Value newValue,
3887 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3888 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3889 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3890 getAsyncOperandsMutable()));
3891}
3892
3893void acc::DataOp::addWaitOnly(MLIRContext *context,
3894 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3895 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3896 effectiveDeviceTypes));
3897}
3898
3899void acc::DataOp::addWaitOperands(
3900 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3901 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3902
3904 if (getWaitOperandsSegments())
3905 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3906
3907 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3908 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3909 getWaitOperandsMutable(), segments));
3910 setWaitOperandsSegments(segments);
3911
3913 if (getHasWaitDevnumAttr())
3914 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3915 hasDevnums.insert(
3916 hasDevnums.end(),
3917 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3918 mlir::BoolAttr::get(context, hasDevnum));
3919 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3920}
3921
3922//===----------------------------------------------------------------------===//
3923// ExitDataOp
3924//===----------------------------------------------------------------------===//
3925
3926LogicalResult acc::ExitDataOp::verify() {
3927 // 2.6.6. Data Exit Directive restriction
3928 // At least one copyout, delete, or detach clause must appear on an exit data
3929 // directive.
3930 if (getDataClauseOperands().empty())
3931 return emitError("at least one operand must be present in dataOperands on "
3932 "the exit data operation");
3933
3934 // The async attribute represent the async clause without value. Therefore the
3935 // attribute and operand cannot appear at the same time.
3936 if (getAsyncOperand() && getAsync())
3937 return emitError("async attribute cannot appear with asyncOperand");
3938
3939 // The wait attribute represent the wait clause without values. Therefore the
3940 // attribute and operands cannot appear at the same time.
3941 if (!getWaitOperands().empty() && getWait())
3942 return emitError("wait attribute cannot appear with waitOperands");
3943
3944 if (getWaitDevnum() && getWaitOperands().empty())
3945 return emitError("wait_devnum cannot appear without waitOperands");
3946
3947 return success();
3948}
3949
3950unsigned ExitDataOp::getNumDataOperands() {
3951 return getDataClauseOperands().size();
3952}
3953
3954Value ExitDataOp::getDataOperand(unsigned i) {
3955 unsigned numOptional = getIfCond() ? 1 : 0;
3956 numOptional += getAsyncOperand() ? 1 : 0;
3957 numOptional += getWaitDevnum() ? 1 : 0;
3958 return getOperand(getWaitOperands().size() + numOptional + i);
3959}
3960
3961void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3962 MLIRContext *context) {
3963 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3964}
3965
3966void ExitDataOp::addAsyncOnly(MLIRContext *context,
3967 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3968 assert(effectiveDeviceTypes.empty());
3969 assert(!getAsyncAttr());
3970 assert(!getAsyncOperand());
3971
3972 setAsyncAttr(mlir::UnitAttr::get(context));
3973}
3974
3975void ExitDataOp::addAsyncOperand(
3976 MLIRContext *context, mlir::Value newValue,
3977 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3978 assert(effectiveDeviceTypes.empty());
3979 assert(!getAsyncAttr());
3980 assert(!getAsyncOperand());
3981
3982 getAsyncOperandMutable().append(newValue);
3983}
3984
3985void ExitDataOp::addWaitOnly(MLIRContext *context,
3986 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3987 assert(effectiveDeviceTypes.empty());
3988 assert(!getWaitAttr());
3989 assert(getWaitOperands().empty());
3990 assert(!getWaitDevnum());
3991
3992 setWaitAttr(mlir::UnitAttr::get(context));
3993}
3994
3995void ExitDataOp::addWaitOperands(
3996 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3997 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3998 assert(effectiveDeviceTypes.empty());
3999 assert(!getWaitAttr());
4000 assert(getWaitOperands().empty());
4001 assert(!getWaitDevnum());
4002
4003 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4004 // operands list.
4005 if (hasDevnum) {
4006 getWaitDevnumMutable().append(newValues.front());
4007 newValues = newValues.drop_front();
4008 }
4009
4010 getWaitOperandsMutable().append(newValues);
4011}
4012
4013//===----------------------------------------------------------------------===//
4014// EnterDataOp
4015//===----------------------------------------------------------------------===//
4016
4017LogicalResult acc::EnterDataOp::verify() {
4018 // 2.6.6. Data Enter Directive restriction
4019 // At least one copyin, create, or attach clause must appear on an enter data
4020 // directive.
4021 if (getDataClauseOperands().empty())
4022 return emitError("at least one operand must be present in dataOperands on "
4023 "the enter data operation");
4024
4025 // The async attribute represent the async clause without value. Therefore the
4026 // attribute and operand cannot appear at the same time.
4027 if (getAsyncOperand() && getAsync())
4028 return emitError("async attribute cannot appear with asyncOperand");
4029
4030 // The wait attribute represent the wait clause without values. Therefore the
4031 // attribute and operands cannot appear at the same time.
4032 if (!getWaitOperands().empty() && getWait())
4033 return emitError("wait attribute cannot appear with waitOperands");
4034
4035 if (getWaitDevnum() && getWaitOperands().empty())
4036 return emitError("wait_devnum cannot appear without waitOperands");
4037
4038 for (mlir::Value operand : getDataClauseOperands())
4039 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4040 operand.getDefiningOp()))
4041 return emitError("expect data entry operation as defining op");
4042
4043 return success();
4044}
4045
4046unsigned EnterDataOp::getNumDataOperands() {
4047 return getDataClauseOperands().size();
4048}
4049
4050Value EnterDataOp::getDataOperand(unsigned i) {
4051 unsigned numOptional = getIfCond() ? 1 : 0;
4052 numOptional += getAsyncOperand() ? 1 : 0;
4053 numOptional += getWaitDevnum() ? 1 : 0;
4054 return getOperand(getWaitOperands().size() + numOptional + i);
4055}
4056
4057void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4058 MLIRContext *context) {
4059 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4060}
4061
4062void EnterDataOp::addAsyncOnly(
4063 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4064 assert(effectiveDeviceTypes.empty());
4065 assert(!getAsyncAttr());
4066 assert(!getAsyncOperand());
4067
4068 setAsyncAttr(mlir::UnitAttr::get(context));
4069}
4070
4071void EnterDataOp::addAsyncOperand(
4072 MLIRContext *context, mlir::Value newValue,
4073 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4074 assert(effectiveDeviceTypes.empty());
4075 assert(!getAsyncAttr());
4076 assert(!getAsyncOperand());
4077
4078 getAsyncOperandMutable().append(newValue);
4079}
4080
4081void EnterDataOp::addWaitOnly(MLIRContext *context,
4082 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4083 assert(effectiveDeviceTypes.empty());
4084 assert(!getWaitAttr());
4085 assert(getWaitOperands().empty());
4086 assert(!getWaitDevnum());
4087
4088 setWaitAttr(mlir::UnitAttr::get(context));
4089}
4090
4091void EnterDataOp::addWaitOperands(
4092 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4093 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4094 assert(effectiveDeviceTypes.empty());
4095 assert(!getWaitAttr());
4096 assert(getWaitOperands().empty());
4097 assert(!getWaitDevnum());
4098
4099 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4100 // operands list.
4101 if (hasDevnum) {
4102 getWaitDevnumMutable().append(newValues.front());
4103 newValues = newValues.drop_front();
4104 }
4105
4106 getWaitOperandsMutable().append(newValues);
4107}
4108
4109//===----------------------------------------------------------------------===//
4110// AtomicReadOp
4111//===----------------------------------------------------------------------===//
4112
4113LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4114
4115//===----------------------------------------------------------------------===//
4116// AtomicWriteOp
4117//===----------------------------------------------------------------------===//
4118
4119LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4120
4121//===----------------------------------------------------------------------===//
4122// AtomicUpdateOp
4123//===----------------------------------------------------------------------===//
4124
4125LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4126 PatternRewriter &rewriter) {
4127 if (op.isNoOp()) {
4128 rewriter.eraseOp(op);
4129 return success();
4130 }
4131
4132 if (Value writeVal = op.getWriteOpVal()) {
4133 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4134 op.getIfCond());
4135 return success();
4136 }
4137
4138 return failure();
4139}
4140
4141LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4142
4143LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4144
4145//===----------------------------------------------------------------------===//
4146// AtomicCaptureOp
4147//===----------------------------------------------------------------------===//
4148
4149AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4150 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4151 return op;
4152 return dyn_cast<AtomicReadOp>(getSecondOp());
4153}
4154
4155AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4156 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4157 return op;
4158 return dyn_cast<AtomicWriteOp>(getSecondOp());
4159}
4160
4161AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4162 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4163 return op;
4164 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4165}
4166
4167LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4168
4169//===----------------------------------------------------------------------===//
4170// DeclareEnterOp
4171//===----------------------------------------------------------------------===//
4172
4173template <typename Op>
4174static LogicalResult
4176 bool requireAtLeastOneOperand = true) {
4177 if (operands.empty() && requireAtLeastOneOperand)
4178 return emitError(
4179 op->getLoc(),
4180 "at least one operand must appear on the declare operation");
4181
4182 for (mlir::Value operand : operands) {
4183 if (isa<BlockArgument>(operand) ||
4184 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4185 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4186 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4187 operand.getDefiningOp()))
4188 return op.emitError(
4189 "expect valid declare data entry operation or acc.getdeviceptr "
4190 "as defining op");
4191
4192 mlir::Value var{getVar(operand.getDefiningOp())};
4193 assert(var && "declare operands can only be data entry operations which "
4194 "must have var");
4195 (void)var;
4196 std::optional<mlir::acc::DataClause> dataClauseOptional{
4197 getDataClause(operand.getDefiningOp())};
4198 assert(dataClauseOptional.has_value() &&
4199 "declare operands can only be data entry operations which must have "
4200 "dataClause");
4201 (void)dataClauseOptional;
4202 }
4203
4204 return success();
4205}
4206
4207LogicalResult acc::DeclareEnterOp::verify() {
4208 return checkDeclareOperands(*this, this->getDataClauseOperands());
4209}
4210
4211//===----------------------------------------------------------------------===//
4212// DeclareExitOp
4213//===----------------------------------------------------------------------===//
4214
4215LogicalResult acc::DeclareExitOp::verify() {
4216 if (getToken())
4217 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4218 /*requireAtLeastOneOperand=*/false);
4219 return checkDeclareOperands(*this, this->getDataClauseOperands());
4220}
4221
4222//===----------------------------------------------------------------------===//
4223// DeclareOp
4224//===----------------------------------------------------------------------===//
4225
4226LogicalResult acc::DeclareOp::verify() {
4227 return checkDeclareOperands(*this, this->getDataClauseOperands());
4228}
4229
4230//===----------------------------------------------------------------------===//
4231// RoutineOp
4232//===----------------------------------------------------------------------===//
4233
4234static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4235 acc::DeviceType dtype) {
4236 unsigned parallelism = 0;
4237 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4238 parallelism += op.hasWorker(dtype) ? 1 : 0;
4239 parallelism += op.hasVector(dtype) ? 1 : 0;
4240 parallelism += op.hasSeq(dtype) ? 1 : 0;
4241 return parallelism;
4242}
4243
4244LogicalResult acc::RoutineOp::verify() {
4245 unsigned baseParallelism =
4246 getParallelismForDeviceType(*this, acc::DeviceType::None);
4247
4248 if (baseParallelism > 1)
4249 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4250 "be present at the same time";
4251
4252 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4253 ++dtypeInt) {
4254 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4255 if (dtype == acc::DeviceType::None)
4256 continue;
4257 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4258
4259 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4260 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4261 "be present at the same time for device_type `"
4262 << acc::stringifyDeviceType(dtype) << "`";
4263 }
4264
4265 return success();
4266}
4267
4268static ParseResult parseBindName(OpAsmParser &parser,
4269 mlir::ArrayAttr &bindIdName,
4270 mlir::ArrayAttr &bindStrName,
4271 mlir::ArrayAttr &deviceIdTypes,
4272 mlir::ArrayAttr &deviceStrTypes) {
4273 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4274 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4275 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4276 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4277
4278 if (failed(parser.parseCommaSeparatedList([&]() {
4279 mlir::Attribute newAttr;
4280 bool isSymbolRefAttr;
4281 auto parseResult = parser.parseAttribute(newAttr);
4282 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4283 bindIdNameAttrs.push_back(symbolRefAttr);
4284 isSymbolRefAttr = true;
4285 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4286 bindStrNameAttrs.push_back(stringAttr);
4287 isSymbolRefAttr = false;
4288 }
4289 if (parseResult)
4290 return failure();
4291 if (failed(parser.parseOptionalLSquare())) {
4292 if (isSymbolRefAttr) {
4293 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4294 parser.getContext(), mlir::acc::DeviceType::None));
4295 } else {
4296 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4297 parser.getContext(), mlir::acc::DeviceType::None));
4298 }
4299 } else {
4300 if (isSymbolRefAttr) {
4301 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4302 parser.parseRSquare())
4303 return failure();
4304 } else {
4305 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4306 parser.parseRSquare())
4307 return failure();
4308 }
4309 }
4310 return success();
4311 })))
4312 return failure();
4313
4314 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4315 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4316 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4317 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4318
4319 return success();
4320}
4321
4323 std::optional<mlir::ArrayAttr> bindIdName,
4324 std::optional<mlir::ArrayAttr> bindStrName,
4325 std::optional<mlir::ArrayAttr> deviceIdTypes,
4326 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4327 // Create combined vectors for all bind names and device types
4330
4331 // Append bindIdName and deviceIdTypes
4332 if (hasDeviceTypeValues(deviceIdTypes)) {
4333 allBindNames.append(bindIdName->begin(), bindIdName->end());
4334 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4335 }
4336
4337 // Append bindStrName and deviceStrTypes
4338 if (hasDeviceTypeValues(deviceStrTypes)) {
4339 allBindNames.append(bindStrName->begin(), bindStrName->end());
4340 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4341 }
4342
4343 // Print the combined sequence
4344 if (!allBindNames.empty())
4345 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4346 [&](const auto &pair) {
4347 p << std::get<0>(pair);
4348 printSingleDeviceType(p, std::get<1>(pair));
4349 });
4350}
4351
4352static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4353 mlir::ArrayAttr &gang,
4354 mlir::ArrayAttr &gangDim,
4355 mlir::ArrayAttr &gangDimDeviceTypes) {
4356
4357 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4358 gangDimDeviceTypeAttrs;
4359 bool needCommaBeforeOperands = false;
4360
4361 // Gang keyword only
4362 if (failed(parser.parseOptionalLParen())) {
4363 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4364 parser.getContext(), mlir::acc::DeviceType::None));
4365 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4366 return success();
4367 }
4368
4369 // Parse keyword only attributes
4370 if (succeeded(parser.parseOptionalLSquare())) {
4371 if (failed(parser.parseCommaSeparatedList([&]() {
4372 if (parser.parseAttribute(gangAttrs.emplace_back()))
4373 return failure();
4374 return success();
4375 })))
4376 return failure();
4377 if (parser.parseRSquare())
4378 return failure();
4379 needCommaBeforeOperands = true;
4380 }
4381
4382 if (needCommaBeforeOperands && failed(parser.parseComma()))
4383 return failure();
4384
4385 if (failed(parser.parseCommaSeparatedList([&]() {
4386 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4387 parser.parseColon() ||
4388 parser.parseAttribute(gangDimAttrs.emplace_back()))
4389 return failure();
4390 if (succeeded(parser.parseOptionalLSquare())) {
4391 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4392 parser.parseRSquare())
4393 return failure();
4394 } else {
4395 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4396 parser.getContext(), mlir::acc::DeviceType::None));
4397 }
4398 return success();
4399 })))
4400 return failure();
4401
4402 if (failed(parser.parseRParen()))
4403 return failure();
4404
4405 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4406 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4407 gangDimDeviceTypes =
4408 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4409
4410 return success();
4411}
4412
4414 std::optional<mlir::ArrayAttr> gang,
4415 std::optional<mlir::ArrayAttr> gangDim,
4416 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4417
4418 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4419 gang->size() == 1) {
4420 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4421 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4422 return;
4423 }
4424
4425 p << "(";
4426
4427 printDeviceTypes(p, gang);
4428
4429 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4430 p << ", ";
4431
4432 if (hasDeviceTypeValues(gangDimDeviceTypes))
4433 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4434 [&](const auto &pair) {
4435 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4436 p << std::get<0>(pair);
4437 printSingleDeviceType(p, std::get<1>(pair));
4438 });
4439
4440 p << ")";
4441}
4442
4443static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4444 mlir::ArrayAttr &deviceTypes) {
4446 // Keyword only
4447 if (failed(parser.parseOptionalLParen())) {
4448 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4449 parser.getContext(), mlir::acc::DeviceType::None));
4450 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4451 return success();
4452 }
4453
4454 // Parse device type attributes
4455 if (succeeded(parser.parseOptionalLSquare())) {
4456 if (failed(parser.parseCommaSeparatedList([&]() {
4457 if (parser.parseAttribute(attributes.emplace_back()))
4458 return failure();
4459 return success();
4460 })))
4461 return failure();
4462 if (parser.parseRSquare() || parser.parseRParen())
4463 return failure();
4464 }
4465 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4466 return success();
4467}
4468
4469static void
4471 std::optional<mlir::ArrayAttr> deviceTypes) {
4472
4473 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4474 auto deviceTypeAttr =
4475 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4476 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4477 return;
4478 }
4479
4480 if (!hasDeviceTypeValues(deviceTypes))
4481 return;
4482
4483 p << "([";
4484 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4485 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4486 p << dTypeAttr;
4487 });
4488 p << "])";
4489}
4490
4491bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4492
4493bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4494 return hasDeviceType(getWorker(), deviceType);
4495}
4496
4497bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4498
4499bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4500 return hasDeviceType(getVector(), deviceType);
4501}
4502
4503bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4504
4505bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4506 return hasDeviceType(getSeq(), deviceType);
4507}
4508
4509std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4510RoutineOp::getBindNameValue() {
4511 return getBindNameValue(mlir::acc::DeviceType::None);
4512}
4513
4514std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4515RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4516 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4517 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4518 return std::nullopt;
4519 }
4520
4521 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4522 auto attr = (*getBindIdName())[*pos];
4523 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4524 assert(symbolRefAttr && "expected SymbolRef");
4525 return symbolRefAttr;
4526 }
4527
4528 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4529 auto attr = (*getBindStrName())[*pos];
4530 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4531 assert(stringAttr && "expected String");
4532 return stringAttr;
4533 }
4534
4535 return std::nullopt;
4536}
4537
4538bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4539
4540bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4541 return hasDeviceType(getGang(), deviceType);
4542}
4543
4544std::optional<int64_t> RoutineOp::getGangDimValue() {
4545 return getGangDimValue(mlir::acc::DeviceType::None);
4546}
4547
4548std::optional<int64_t>
4549RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4550 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4551 return std::nullopt;
4552 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4553 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4554 return intAttr.getInt();
4555 }
4556 return std::nullopt;
4557}
4558
4559void RoutineOp::addSeq(MLIRContext *context,
4560 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4561 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4562 effectiveDeviceTypes));
4563}
4564
4565void RoutineOp::addVector(MLIRContext *context,
4566 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4567 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4568 effectiveDeviceTypes));
4569}
4570
4571void RoutineOp::addWorker(MLIRContext *context,
4572 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4573 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4574 effectiveDeviceTypes));
4575}
4576
4577void RoutineOp::addGang(MLIRContext *context,
4578 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4579 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4580 effectiveDeviceTypes));
4581}
4582
4583void RoutineOp::addGang(MLIRContext *context,
4584 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4585 uint64_t val) {
4588
4589 if (getGangDimAttr())
4590 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4591 if (getGangDimDeviceTypeAttr())
4592 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4593
4594 assert(dimValues.size() == deviceTypes.size());
4595
4596 if (effectiveDeviceTypes.empty()) {
4597 dimValues.push_back(
4598 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4599 deviceTypes.push_back(
4600 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4601 } else {
4602 for (DeviceType dt : effectiveDeviceTypes) {
4603 dimValues.push_back(
4604 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4605 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4606 }
4607 }
4608 assert(dimValues.size() == deviceTypes.size());
4609
4610 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4611 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4612}
4613
4614void RoutineOp::addBindStrName(MLIRContext *context,
4615 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4616 mlir::StringAttr val) {
4617 unsigned before = getBindStrNameDeviceTypeAttr()
4618 ? getBindStrNameDeviceTypeAttr().size()
4619 : 0;
4620
4621 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4622 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4623 unsigned after = getBindStrNameDeviceTypeAttr().size();
4624
4626 if (getBindStrNameAttr())
4627 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4628 for (unsigned i = 0; i < after - before; ++i)
4629 vals.push_back(val);
4630
4631 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4632}
4633
4634void RoutineOp::addBindIDName(MLIRContext *context,
4635 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4636 mlir::SymbolRefAttr val) {
4637 unsigned before =
4638 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4639
4640 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4641 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4642 unsigned after = getBindIdNameDeviceTypeAttr().size();
4643
4645 if (getBindIdNameAttr())
4646 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4647 for (unsigned i = 0; i < after - before; ++i)
4648 vals.push_back(val);
4649
4650 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4651}
4652
4653//===----------------------------------------------------------------------===//
4654// InitOp
4655//===----------------------------------------------------------------------===//
4656
4657LogicalResult acc::InitOp::verify() {
4658 Operation *currOp = *this;
4659 while ((currOp = currOp->getParentOp()))
4660 if (isComputeOperation(currOp))
4661 return emitOpError("cannot be nested in a compute operation");
4662 return success();
4663}
4664
4665void acc::InitOp::addDeviceType(MLIRContext *context,
4666 mlir::acc::DeviceType deviceType) {
4668 if (getDeviceTypesAttr())
4669 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4670
4671 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4672 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4673}
4674
4675//===----------------------------------------------------------------------===//
4676// ShutdownOp
4677//===----------------------------------------------------------------------===//
4678
4679LogicalResult acc::ShutdownOp::verify() {
4680 Operation *currOp = *this;
4681 while ((currOp = currOp->getParentOp()))
4682 if (isComputeOperation(currOp))
4683 return emitOpError("cannot be nested in a compute operation");
4684 return success();
4685}
4686
4687void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4688 mlir::acc::DeviceType deviceType) {
4690 if (getDeviceTypesAttr())
4691 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4692
4693 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4694 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4695}
4696
4697//===----------------------------------------------------------------------===//
4698// SetOp
4699//===----------------------------------------------------------------------===//
4700
4701LogicalResult acc::SetOp::verify() {
4702 Operation *currOp = *this;
4703 while ((currOp = currOp->getParentOp()))
4704 if (isComputeOperation(currOp))
4705 return emitOpError("cannot be nested in a compute operation");
4706 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4707 return emitOpError("at least one default_async, device_num, or device_type "
4708 "operand must appear");
4709 return success();
4710}
4711
4712//===----------------------------------------------------------------------===//
4713// UpdateOp
4714//===----------------------------------------------------------------------===//
4715
4716LogicalResult acc::UpdateOp::verify() {
4717 // At least one of host or device should have a value.
4718 if (getDataClauseOperands().empty())
4719 return emitError("at least one value must be present in dataOperands");
4720
4722 getAsyncOperandsDeviceTypeAttr(),
4723 "async")))
4724 return failure();
4725
4727 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4728 getWaitOperandsDeviceTypeAttr(), "wait")))
4729 return failure();
4730
4732 return failure();
4733
4734 for (mlir::Value operand : getDataClauseOperands())
4735 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4736 operand.getDefiningOp()))
4737 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4738 "as defining op");
4739
4740 return success();
4741}
4742
4743unsigned UpdateOp::getNumDataOperands() {
4744 return getDataClauseOperands().size();
4745}
4746
4747Value UpdateOp::getDataOperand(unsigned i) {
4748 unsigned numOptional = getAsyncOperands().size();
4749 numOptional += getIfCond() ? 1 : 0;
4750 return getOperand(getWaitOperands().size() + numOptional + i);
4751}
4752
4753void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4754 MLIRContext *context) {
4755 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4756}
4757
4758bool UpdateOp::hasAsyncOnly() {
4759 return hasAsyncOnly(mlir::acc::DeviceType::None);
4760}
4761
4762bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4763 return hasDeviceType(getAsyncOnly(), deviceType);
4764}
4765
4766mlir::Value UpdateOp::getAsyncValue() {
4767 return getAsyncValue(mlir::acc::DeviceType::None);
4768}
4769
4770mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4772 return {};
4773
4774 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4775 return getAsyncOperands()[*pos];
4776
4777 return {};
4778}
4779
4780bool UpdateOp::hasWaitOnly() {
4781 return hasWaitOnly(mlir::acc::DeviceType::None);
4782}
4783
4784bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4785 return hasDeviceType(getWaitOnly(), deviceType);
4786}
4787
4788mlir::Operation::operand_range UpdateOp::getWaitValues() {
4789 return getWaitValues(mlir::acc::DeviceType::None);
4790}
4791
4793UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4795 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4796 getHasWaitDevnum(), deviceType);
4797}
4798
4799mlir::Value UpdateOp::getWaitDevnum() {
4800 return getWaitDevnum(mlir::acc::DeviceType::None);
4801}
4802
4803mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4804 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4805 getWaitOperandsSegments(), getHasWaitDevnum(),
4806 deviceType);
4807}
4808
4809void UpdateOp::addAsyncOnly(MLIRContext *context,
4810 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4811 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4812 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4813}
4814
4815void UpdateOp::addAsyncOperand(
4816 MLIRContext *context, mlir::Value newValue,
4817 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4818 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4819 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4820 getAsyncOperandsMutable()));
4821}
4822
4823void UpdateOp::addWaitOnly(MLIRContext *context,
4824 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4825 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4826 effectiveDeviceTypes));
4827}
4828
4829void UpdateOp::addWaitOperands(
4830 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4831 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4832
4834 if (getWaitOperandsSegments())
4835 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4836
4837 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4838 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4839 getWaitOperandsMutable(), segments));
4840 setWaitOperandsSegments(segments);
4841
4843 if (getHasWaitDevnumAttr())
4844 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4845 hasDevnums.insert(
4846 hasDevnums.end(),
4847 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4848 mlir::BoolAttr::get(context, hasDevnum));
4849 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4850}
4851
4852//===----------------------------------------------------------------------===//
4853// WaitOp
4854//===----------------------------------------------------------------------===//
4855
4856LogicalResult acc::WaitOp::verify() {
4857 // The async attribute represent the async clause without value. Therefore the
4858 // attribute and operand cannot appear at the same time.
4859 if (getAsyncOperand() && getAsync())
4860 return emitError("async attribute cannot appear with asyncOperand");
4861
4862 if (getWaitDevnum() && getWaitOperands().empty())
4863 return emitError("wait_devnum cannot appear without waitOperands");
4864
4865 return success();
4866}
4867
4868#define GET_OP_CLASSES
4869#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4870
4871#define GET_ATTRDEF_CLASSES
4872#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4873
4874#define GET_TYPEDEF_CLASSES
4875#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4876
4877//===----------------------------------------------------------------------===//
4878// acc dialect utilities
4879//===----------------------------------------------------------------------===//
4880
4883 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4885 accDataClauseOp)
4886 .Case<ACC_DATA_ENTRY_OPS>(
4887 [&](auto entry) { return entry.getVarPtr(); })
4888 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4889 [&](auto exit) { return exit.getVarPtr(); })
4890 .Default([&](mlir::Operation *) {
4892 })};
4893 return varPtr;
4894}
4895
4897 auto varPtr{
4899 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4900 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4901 return varPtr;
4902}
4903
4905 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4906 .Case<ACC_DATA_ENTRY_OPS>(
4907 [&](auto entry) { return entry.getVarType(); })
4908 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4909 [&](auto exit) { return exit.getVarType(); })
4910 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4911 return varType;
4912}
4913
4916 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4918 accDataClauseOp)
4919 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4920 [&](auto dataClause) { return dataClause.getAccPtr(); })
4921 .Default([&](mlir::Operation *) {
4923 })};
4924 return accPtr;
4925}
4926
4928 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4930 [&](auto dataClause) { return dataClause.getAccVar(); })
4931 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4932 return accPtr;
4933}
4934
4936 auto varPtrPtr{
4938 .Case<ACC_DATA_ENTRY_OPS>(
4939 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4940 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4941 return varPtrPtr;
4942}
4943
4948 accDataClauseOp)
4949 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4951 dataClause.getBounds().begin(), dataClause.getBounds().end());
4952 })
4953 .Default([&](mlir::Operation *) {
4955 })};
4956 return bounds;
4957}
4958
4962 accDataClauseOp)
4963 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4965 dataClause.getAsyncOperands().begin(),
4966 dataClause.getAsyncOperands().end());
4967 })
4968 .Default([&](mlir::Operation *) {
4970 });
4971}
4972
4973mlir::ArrayAttr
4976 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4977 return dataClause.getAsyncOperandsDeviceTypeAttr();
4978 })
4979 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4980}
4981
4982mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4985 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4986 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4987}
4988
4989std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4990 auto name{
4992 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4993 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4994 return {};
4995 })};
4996 return name;
4997}
4998
4999std::optional<mlir::acc::DataClause>
5001 auto dataClause{
5003 accDataEntryOp)
5004 .Case<ACC_DATA_ENTRY_OPS>(
5005 [&](auto entry) { return entry.getDataClause(); })
5006 .Default([&](mlir::Operation *) { return std::nullopt; })};
5007 return dataClause;
5008}
5009
5011 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5012 .Case<ACC_DATA_ENTRY_OPS>(
5013 [&](auto entry) { return entry.getImplicit(); })
5014 .Default([&](mlir::Operation *) { return false; })};
5015 return implicit;
5016}
5017
5019 auto dataOperands{
5022 [&](auto entry) { return entry.getDataClauseOperands(); })
5023 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5024 return dataOperands;
5025}
5026
5029 auto dataOperands{
5032 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5033 .Default([&](mlir::Operation *) { return nullptr; })};
5034 return dataOperands;
5035}
5036
5037mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5038 auto recipe{
5040 .Case<ACC_DATA_ENTRY_OPS>(
5041 [&](auto entry) { return entry.getRecipeAttr(); })
5042 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5043 return recipe;
5044}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4413
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1228
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3164
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1774
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4268
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:829
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1242
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:609
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2270
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:822
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:762
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:593
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:731
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2281
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2186
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:535
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4470
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2973
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2521
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3180
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4175
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:669
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2475
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:553
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:651
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:685
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3548
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1729
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2317
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1856
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:677
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:740
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:564
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:577
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2056
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:716
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3579
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4443
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4352
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2169
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2344
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2460
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2123
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2436
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:803
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2992
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1509
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2505
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2100
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:695
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1743
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2417
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:539
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3119
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2355
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:774
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:629
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1784
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4234
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2106
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2541
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4322
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If the an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:69
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:46
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:54
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4927
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4896
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4915
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5000
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5028
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4945
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5018
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4989
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5010
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5037
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4960
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:167
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4935
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:204
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4982
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtains the varType from a data clause operation which records the type of variable.
Definition OpenACC.cpp:4904
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4882
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4974
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.