MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248
249 bool isDeviceData(Type pointer, Value var) const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
253 }
254};
255
256struct LLVMPointerPointerLikeModel
257 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
259 Type getElementType(Type pointer) const { return Type(); }
260
261 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType) const {
264 // For LLVM pointers, we need the valueType to determine what to load
265 if (!valueType)
266 return {};
267
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
269 }
270
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
272 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
274 return true;
275 }
276};
277
278struct MemrefAddressOfGlobalModel
279 : public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op) const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
284 }
285};
286
287struct MemrefGlobalVariableModel
288 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
289 memref::GlobalOp> {
290 bool isConstant(Operation *op) const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
293 }
294
295 Region *getInitRegion(Operation *op) const {
296 // GlobalOp uses attributes for initialization, not regions
297 return nullptr;
298 }
299
300 bool isDeviceData(Operation *op) const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
304 }
305};
306
307struct GPULaunchOffloadRegionModel
308 : public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
311 return cast<gpu::LaunchOp>(op).getBody();
312 }
313};
314
315/// Helper function for any of the times we need to modify an ArrayAttr based on
316/// a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
318/// list is empty).
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
321 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
325
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
329
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
332
333 return mlir::ArrayAttr::get(context, deviceTypes);
334}
335
336/// Helper function for any of the times we need to add operands that are
337/// affected by a device type list. Returns a new ArrayAttr with all of the
338/// existingDeviceTypes, plus the effective new ones (or an added none, if the
339/// new list is empty). Additionally, adds the arguments to the argCollection
340/// the correct number of times. This will also update a 'segments' array, even
341/// if it won't be used.
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
344 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
345 mlir::MutableOperandRange argCollection,
346 llvm::SmallVector<int32_t> &segments) {
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
350
351 if (newDeviceTypes.empty()) {
352 argCollection.append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
356 }
357
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
362 }
363
364 return mlir::ArrayAttr::get(context, deviceTypes);
365}
366
367/// Overload for when the 'segments' aren't needed.
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
370 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
371 mlir::MutableOperandRange argCollection) {
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
376}
377} // namespace
378
379//===----------------------------------------------------------------------===//
380// OpenACC operations
381//===----------------------------------------------------------------------===//
382
383void OpenACCDialect::initialize() {
384 addOperations<
385#define GET_OP_LIST
386#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
387 >();
388 addAttributes<
389#define GET_ATTRDEF_LIST
390#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
391 >();
392 addTypes<
393#define GET_TYPEDEF_LIST
394#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
395 >();
396
397 // By attaching interfaces here, we make the OpenACC dialect dependent on
398 // the other dialects. This is probably better than having dialects like LLVM
399 // and memref be dependent on OpenACC.
400 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
401 *getContext());
402 UnrankedMemRefType::attachInterface<
403 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
404 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
405 *getContext());
406
407 // Attach operation interfaces
408 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
409 *getContext());
410 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
411 gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
412}
413
414//===----------------------------------------------------------------------===//
415// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
416// acc.kernel_environment / acc.data / acc.host_data / acc.loop
417//===----------------------------------------------------------------------===//
418
419/// Generic helper for single-region OpenACC ops that execute their body once
420/// and then return to the parent operation with their results (if any).
421static void
423 RegionBranchPoint point,
425 if (point.isParent()) {
426 regions.push_back(RegionSuccessor(&region));
427 return;
428 }
429
430 regions.push_back(RegionSuccessor::parent());
431}
432
434 RegionSuccessor successor) {
435 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
436}
437
438void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
440 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
441 regions);
442}
443
444ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
445 return getSingleRegionSuccessorInputs(getOperation(), successor);
446}
447
448void ParallelOp::getSuccessorRegions(
450 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
451 regions);
452}
453
454ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
455 return getSingleRegionSuccessorInputs(getOperation(), successor);
456}
457
458void SerialOp::getSuccessorRegions(RegionBranchPoint point,
460 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
461 regions);
462}
463
464ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
465 return getSingleRegionSuccessorInputs(getOperation(), successor);
466}
467
468void DataOp::getSuccessorRegions(RegionBranchPoint point,
470 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
471 regions);
472}
473
474ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
475 return getSingleRegionSuccessorInputs(getOperation(), successor);
476}
477
478void HostDataOp::getSuccessorRegions(
480 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
481 regions);
482}
483
484ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
485 return getSingleRegionSuccessorInputs(getOperation(), successor);
486}
487
488void LoopOp::getSuccessorRegions(RegionBranchPoint point,
490 // Unstructured loops: the body may contain arbitrary CFG and early exits.
491 // At the RegionBranch level, only model entry into the body and exit to the
492 // parent; any backedges are represented inside the region CFG.
493 if (getUnstructured()) {
494 if (point.isParent()) {
495 regions.push_back(RegionSuccessor(&getRegion()));
496 return;
497 }
498 regions.push_back(RegionSuccessor::parent());
499 return;
500 }
501
502 // Structured loops: model a loop-shaped region graph similar to scf.for.
503 regions.push_back(RegionSuccessor(&getRegion()));
504 regions.push_back(RegionSuccessor::parent());
505}
506
507ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
508 return getSingleRegionSuccessorInputs(getOperation(), successor);
509}
510
511//===----------------------------------------------------------------------===//
512// RegionBranchTerminatorOpInterface
513//===----------------------------------------------------------------------===//
514
516TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
517 // `acc.terminator` does not forward operands.
518 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
519}
520
521//===----------------------------------------------------------------------===//
522// device_type support helpers
523//===----------------------------------------------------------------------===//
524
525static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
527}
528
529static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
530 mlir::acc::DeviceType deviceType) {
531 if (!hasDeviceTypeValues(arrayAttr))
532 return false;
533
534 for (auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
537 return true;
538 }
539
540 return false;
541}
542
// NOTE(review): the opening line(s) of this definition (return type, name and
// the declaration of `p`) are absent from this listing: the embedded line
// numbering jumps from 542 to 544. Restore them from the upstream file.
//
// Visible behavior: prints nothing when `deviceTypes` is absent or empty;
// otherwise streams the attributes to `p` as a comma-separated list enclosed
// in square brackets.
544 std::optional<mlir::ArrayAttr> deviceTypes) {
545 if (!hasDeviceTypeValues(deviceTypes))
546 return;
547
548 p << "[";
549 llvm::interleaveComma(*deviceTypes, p,
550 [&](mlir::Attribute attr) { p << attr; });
551 p << "]";
552}
553
554static std::optional<unsigned> findSegment(ArrayAttr segments,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
561 ++segmentIdx;
562 }
563 return std::nullopt;
564}
565
567getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
569 std::optional<llvm::ArrayRef<int32_t>> segments,
570 mlir::acc::DeviceType deviceType) {
571 if (!arrayAttr)
572 return range.take_front(0);
573 if (auto pos = findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
578 }
579 return range.take_front(0);
580}
581
582static mlir::Value
583getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
585 std::optional<llvm::ArrayRef<int32_t>> segments,
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
588 if (!hasDeviceTypeValues(deviceTypeAttr))
589 return {};
590 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
592 return getValuesFromSegments(deviceTypeAttr, operands, segments,
593 deviceType)
594 .front();
595 return {};
596}
597
599getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
601 std::optional<llvm::ArrayRef<int32_t>> segments,
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
604 auto range =
605 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
606 if (range.empty())
607 return range;
608 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1); // first value is devnum
613 }
614 }
615 return range;
616}
617
618template <typename Op>
619static LogicalResult checkWaitAndAsyncConflict(Op op) {
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
621 ++dtypeInt) {
622 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
623
624 // The asyncOnly attribute represent the async clause without value.
625 // Therefore the attribute and operand cannot appear at the same time.
626 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
627 op.hasAsyncOnly(dtype))
628 return op.emitError(
629 "asyncOnly attribute cannot appear with asyncOperand");
630
631 // The wait attribute represent the wait clause without values. Therefore
632 // the attribute and operands cannot appear at the same time.
633 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
634 op.hasWaitOnly(dtype))
635 return op.emitError("wait attribute cannot appear with waitOperands");
636 }
637 return success();
638}
639
640template <typename Op>
641static LogicalResult checkVarAndVarType(Op op) {
642 if (!op.getVar())
643 return op.emitError("must have var operand");
644
645 // A variable must have a type that is either pointer-like or mappable.
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.emitError("var must be mappable or pointer-like");
649
650 // When it is a pointer-like type, the varType must capture the target type.
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.emitError("varType must capture the element type of var");
654
655 return success();
656}
657
658template <typename Op>
659static LogicalResult checkVarAndAccVar(Op op) {
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.emitError("input and output types must match");
662
663 return success();
664}
665
666template <typename Op>
667static LogicalResult checkNoModifier(Op op) {
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.emitError("no data clause modifiers are allowed");
670 return success();
671}
672
673template <typename Op>
674static LogicalResult
675checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
677 return op.emitError(
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
680
681 return success();
682}
683
684template <typename OpT, typename RecipeOpT>
685static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
686 // Mappable types do not need a recipe because it is possible to generate one
687 // from its API. Reject reductions though because no API is available for them
688 // at this time.
689 if (mlir::acc::isMappableType(op.getVar().getType()) &&
690 !std::is_same_v<OpT, acc::ReductionOp>)
691 return success();
692
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
694 if (!operandRecipe)
695 return op->emitOpError() << "recipe expected for " << operandName;
696
697 auto decl =
699 if (!decl)
700 return op->emitOpError()
701 << "expected symbol reference " << operandRecipe << " to point to a "
702 << operandName << " declaration";
703 return success();
704}
705
706static ParseResult parseVar(mlir::OpAsmParser &parser,
708 // Either `var` or `varPtr` keyword is required.
709 if (failed(parser.parseOptionalKeyword("varPtr"))) {
710 if (failed(parser.parseKeyword("var")))
711 return failure();
712 }
713 if (failed(parser.parseLParen()))
714 return failure();
715 if (failed(parser.parseOperand(var)))
716 return failure();
717
718 return success();
719}
720
// NOTE(review): the first line of this definition (return type, name and the
// leading parameters, including the declaration of `p`) is absent from this
// listing: the embedded numbering jumps from 720 to 722. Restore it from the
// upstream file.
//
// Visible behavior: prints `varPtr(` when the type of `var` is a
// PointerLikeType and `var(` otherwise, followed by the operand. The closing
// parenthesis is not printed here.
722 mlir::Value var) {
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
724 p << "varPtr(";
725 else
726 p << "var(";
727 p.printOperand(var);
728}
729
730static ParseResult parseAccVar(mlir::OpAsmParser &parser,
732 mlir::Type &accVarType) {
733 // Either `accVar` or `accPtr` keyword is required.
734 if (failed(parser.parseOptionalKeyword("accPtr"))) {
735 if (failed(parser.parseKeyword("accVar")))
736 return failure();
737 }
738 if (failed(parser.parseLParen()))
739 return failure();
740 if (failed(parser.parseOperand(var)))
741 return failure();
742 if (failed(parser.parseColon()))
743 return failure();
744 if (failed(parser.parseType(accVarType)))
745 return failure();
746 if (failed(parser.parseRParen()))
747 return failure();
748
749 return success();
750}
751
// NOTE(review): the first line of this definition (return type, name and the
// leading parameters, including the declaration of `p`) is absent from this
// listing: the embedded numbering jumps from 751 to 753. Restore it from the
// upstream file.
//
// Visible behavior: prints `accPtr(` when the type of `accVar` is a
// PointerLikeType and `accVar(` otherwise, followed by the operand, ` : `,
// `accVarType` and the closing parenthesis.
753 mlir::Value accVar, mlir::Type accVarType) {
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
755 p << "accPtr(";
756 else
757 p << "accVar(";
758 p.printOperand(accVar);
759 p << " : ";
760 p.printType(accVarType);
761 p << ")";
762}
763
764static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
765 mlir::Type &varPtrType,
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.parseType(varPtrType)))
768 return failure();
769 if (failed(parser.parseRParen()))
770 return failure();
771
772 if (succeeded(parser.parseOptionalKeyword("varType"))) {
773 if (failed(parser.parseLParen()))
774 return failure();
775 mlir::Type varType;
776 if (failed(parser.parseType(varType)))
777 return failure();
778 varTypeAttr = mlir::TypeAttr::get(varType);
779 if (failed(parser.parseRParen()))
780 return failure();
781 } else {
782 // Set `varType` from the element type of the type of `varPtr`.
783 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
784 varTypeAttr = mlir::TypeAttr::get(
785 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
786 else
787 varTypeAttr = mlir::TypeAttr::get(varPtrType);
788 }
789
790 return success();
791}
792
// NOTE(review): the first line of this definition (return type, name and the
// leading parameters, including the declaration of `p`) is absent from this
// listing: the embedded numbering jumps from 792 to 794. Restore it from the
// upstream file.
//
// Visible behavior: prints `varPtrType` and a closing parenthesis, then an
// explicit ` varType(<type>)` suffix only when the stored type differs from
// the default (the element type of a pointer-like `varPtrType`, or
// `varPtrType` itself otherwise).
794 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
795 p.printType(varPtrType);
796 p << ")";
797
798 // Print the `varType` only if it differs from the element type of
799 // `varPtr`'s type.
800 mlir::Type varType = varTypeAttr.getValue();
801 mlir::Type typeToCheckAgainst =
802 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
803 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
804 : varPtrType;
805 if (typeToCheckAgainst != varType) {
806 p << " varType(";
807 p.printType(varType);
808 p << ")";
809 }
810}
811
812static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
813 mlir::SymbolRefAttr &recipeAttr) {
814 if (failed(parser.parseAttribute(recipeAttr)))
815 return failure();
816 return success();
817}
818
// NOTE(review): the first line of this definition (return type, name and the
// leading parameters, including the declaration of `p`) is absent from this
// listing: the embedded numbering jumps from 818 to 820. Restore it from the
// upstream file.
//
// Visible behavior: streams `recipeAttr` to `p`.
820 mlir::SymbolRefAttr recipeAttr) {
821 p << recipeAttr;
822}
823
824//===----------------------------------------------------------------------===//
825// DataBoundsOp
826//===----------------------------------------------------------------------===//
827LogicalResult acc::DataBoundsOp::verify() {
828 auto extent = getExtent();
829 auto upperbound = getUpperbound();
830 if (!extent && !upperbound)
831 return emitError("expected extent or upperbound.");
832 return success();
833}
834
835//===----------------------------------------------------------------------===//
836// PrivateOp
837//===----------------------------------------------------------------------===//
// Verifier for the private data-clause op: the data clause must be
// `acc_private`, the `var` operand must pass checkVarAndVarType, and no data
// clause modifier may be set. A further check follows (see note below).
838LogicalResult acc::PrivateOp::verify() {
839 if (getDataClause() != acc::DataClause::acc_private)
840 return emitError(
841 "data clause associated with private operation must match its intent");
842 if (failed(checkVarAndVarType(*this)))
843 return failure();
844 if (failed(checkNoModifier(*this)))
845 return failure();
// NOTE(review): the argument of the `failed(` call below is absent from this
// listing (embedded numbering jumps from 846 to 848). By analogy with the
// sibling verifiers it is presumably a checkRecipe<...> call; restore it from
// the upstream file.
846 if (failed(
848 return failure();
849 return success();
850}
851
852//===----------------------------------------------------------------------===//
853// FirstprivateOp
854//===----------------------------------------------------------------------===//
// Verifier for the firstprivate data-clause op: the data clause must be
// `acc_firstprivate`, the `var` operand must pass checkVarAndVarType, and no
// data clause modifier may be set. A further check follows (see note below).
855LogicalResult acc::FirstprivateOp::verify() {
856 if (getDataClause() != acc::DataClause::acc_firstprivate)
857 return emitError("data clause associated with firstprivate operation must "
858 "match its intent");
859 if (failed(checkVarAndVarType(*this)))
860 return failure();
861 if (failed(checkNoModifier(*this)))
862 return failure();
// NOTE(review): the line that opens the call completed below is absent from
// this listing (embedded numbering jumps from 862 to 864). It is presumably
// `if (failed(checkRecipe<...>(`; restore it from the upstream file.
864 *this, "firstprivate")))
865 return failure();
866 return success();
867}
868
869//===----------------------------------------------------------------------===//
870// ReductionOp
871//===----------------------------------------------------------------------===//
// Verifier for the reduction data-clause op: the data clause must be
// `acc_reduction`, the `var` operand must pass checkVarAndVarType, and no
// data clause modifier may be set. A further check follows (see note below).
872LogicalResult acc::ReductionOp::verify() {
873 if (getDataClause() != acc::DataClause::acc_reduction)
874 return emitError("data clause associated with reduction operation must "
875 "match its intent");
876 if (failed(checkVarAndVarType(*this)))
877 return failure();
878 if (failed(checkNoModifier(*this)))
879 return failure();
// NOTE(review): the line that opens the call completed below is absent from
// this listing (embedded numbering jumps from 879 to 881). It is presumably
// `if (failed(checkRecipe<...>(`; restore it from the upstream file.
881 *this, "reduction")))
882 return failure();
883 return success();
884}
885
886//===----------------------------------------------------------------------===//
887// DevicePtrOp
888//===----------------------------------------------------------------------===//
889LogicalResult acc::DevicePtrOp::verify() {
890 if (getDataClause() != acc::DataClause::acc_deviceptr)
891 return emitError("data clause associated with deviceptr operation must "
892 "match its intent");
893 if (failed(checkVarAndVarType(*this)))
894 return failure();
895 if (failed(checkVarAndAccVar(*this)))
896 return failure();
897 if (failed(checkNoModifier(*this)))
898 return failure();
899 return success();
900}
901
902//===----------------------------------------------------------------------===//
903// PresentOp
904//===----------------------------------------------------------------------===//
905LogicalResult acc::PresentOp::verify() {
906 if (getDataClause() != acc::DataClause::acc_present)
907 return emitError(
908 "data clause associated with present operation must match its intent");
909 if (failed(checkVarAndVarType(*this)))
910 return failure();
911 if (failed(checkVarAndAccVar(*this)))
912 return failure();
913 if (failed(checkNoModifier(*this)))
914 return failure();
915 return success();
916}
917
918//===----------------------------------------------------------------------===//
919// CopyinOp
920//===----------------------------------------------------------------------===//
921LogicalResult acc::CopyinOp::verify() {
922 // Test for all clauses this operation can be decomposed from:
923 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
924 getDataClause() != acc::DataClause::acc_copyin_readonly &&
925 getDataClause() != acc::DataClause::acc_copy &&
926 getDataClause() != acc::DataClause::acc_reduction)
927 return emitError(
928 "data clause associated with copyin operation must match its intent"
929 " or specify original clause this operation was decomposed from");
930 if (failed(checkVarAndVarType(*this)))
931 return failure();
932 if (failed(checkVarAndAccVar(*this)))
933 return failure();
934 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
935 acc::DataClauseModifier::always |
936 acc::DataClauseModifier::capture)))
937 return failure();
938 return success();
939}
940
941bool acc::CopyinOp::isCopyinReadonly() {
942 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
943 acc::bitEnumContainsAny(getModifiers(),
944 acc::DataClauseModifier::readonly);
945}
946
947//===----------------------------------------------------------------------===//
948// CreateOp
949//===----------------------------------------------------------------------===//
950LogicalResult acc::CreateOp::verify() {
951 // Test for all clauses this operation can be decomposed from:
952 if (getDataClause() != acc::DataClause::acc_create &&
953 getDataClause() != acc::DataClause::acc_create_zero &&
954 getDataClause() != acc::DataClause::acc_copyout &&
955 getDataClause() != acc::DataClause::acc_copyout_zero)
956 return emitError(
957 "data clause associated with create operation must match its intent"
958 " or specify original clause this operation was decomposed from");
959 if (failed(checkVarAndVarType(*this)))
960 return failure();
961 if (failed(checkVarAndAccVar(*this)))
962 return failure();
963 // this op is the entry part of copyout, so it also needs to allow all
964 // modifiers allowed on copyout.
965 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
966 acc::DataClauseModifier::always |
967 acc::DataClauseModifier::capture)))
968 return failure();
969 return success();
970}
971
972bool acc::CreateOp::isCreateZero() {
973 // The zero modifier is encoded in the data clause.
974 return getDataClause() == acc::DataClause::acc_create_zero ||
975 getDataClause() == acc::DataClause::acc_copyout_zero ||
976 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
977}
978
979//===----------------------------------------------------------------------===//
980// NoCreateOp
981//===----------------------------------------------------------------------===//
982LogicalResult acc::NoCreateOp::verify() {
983 if (getDataClause() != acc::DataClause::acc_no_create)
984 return emitError("data clause associated with no_create operation must "
985 "match its intent");
986 if (failed(checkVarAndVarType(*this)))
987 return failure();
988 if (failed(checkVarAndAccVar(*this)))
989 return failure();
990 if (failed(checkNoModifier(*this)))
991 return failure();
992 return success();
993}
994
995//===----------------------------------------------------------------------===//
996// AttachOp
997//===----------------------------------------------------------------------===//
998LogicalResult acc::AttachOp::verify() {
999 if (getDataClause() != acc::DataClause::acc_attach)
1000 return emitError(
1001 "data clause associated with attach operation must match its intent");
1002 if (failed(checkVarAndVarType(*this)))
1003 return failure();
1004 if (failed(checkVarAndAccVar(*this)))
1005 return failure();
1006 if (failed(checkNoModifier(*this)))
1007 return failure();
1008 return success();
1009}
1010
1011//===----------------------------------------------------------------------===//
1012// DeclareDeviceResidentOp
1013//===----------------------------------------------------------------------===//
1014
1015LogicalResult acc::DeclareDeviceResidentOp::verify() {
1016 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1017 return emitError("data clause associated with device_resident operation "
1018 "must match its intent");
1019 if (failed(checkVarAndVarType(*this)))
1020 return failure();
1021 if (failed(checkVarAndAccVar(*this)))
1022 return failure();
1023 if (failed(checkNoModifier(*this)))
1024 return failure();
1025 return success();
1026}
1027
1028//===----------------------------------------------------------------------===//
1029// DeclareLinkOp
1030//===----------------------------------------------------------------------===//
1031
1032LogicalResult acc::DeclareLinkOp::verify() {
1033 if (getDataClause() != acc::DataClause::acc_declare_link)
1034 return emitError(
1035 "data clause associated with link operation must match its intent");
1036 if (failed(checkVarAndVarType(*this)))
1037 return failure();
1038 if (failed(checkVarAndAccVar(*this)))
1039 return failure();
1040 if (failed(checkNoModifier(*this)))
1041 return failure();
1042 return success();
1043}
1044
1045//===----------------------------------------------------------------------===//
1046// CopyoutOp
1047//===----------------------------------------------------------------------===//
1048LogicalResult acc::CopyoutOp::verify() {
1049 // Test for all clauses this operation can be decomposed from:
1050 if (getDataClause() != acc::DataClause::acc_copyout &&
1051 getDataClause() != acc::DataClause::acc_copyout_zero &&
1052 getDataClause() != acc::DataClause::acc_copy &&
1053 getDataClause() != acc::DataClause::acc_reduction)
1054 return emitError(
1055 "data clause associated with copyout operation must match its intent"
1056 " or specify original clause this operation was decomposed from");
1057 if (!getVar() || !getAccVar())
1058 return emitError("must have both host and device pointers");
1059 if (failed(checkVarAndVarType(*this)))
1060 return failure();
1061 if (failed(checkVarAndAccVar(*this)))
1062 return failure();
1063 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1064 acc::DataClauseModifier::always |
1065 acc::DataClauseModifier::capture)))
1066 return failure();
1067 return success();
1068}
1069
1070bool acc::CopyoutOp::isCopyoutZero() {
1071 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1072 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1073}
1074
1075//===----------------------------------------------------------------------===//
1076// DeleteOp
1077//===----------------------------------------------------------------------===//
1078LogicalResult acc::DeleteOp::verify() {
1079 // Test for all clauses this operation can be decomposed from:
1080 if (getDataClause() != acc::DataClause::acc_delete &&
1081 getDataClause() != acc::DataClause::acc_create &&
1082 getDataClause() != acc::DataClause::acc_create_zero &&
1083 getDataClause() != acc::DataClause::acc_copyin &&
1084 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1085 getDataClause() != acc::DataClause::acc_present &&
1086 getDataClause() != acc::DataClause::acc_no_create &&
1087 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1088 getDataClause() != acc::DataClause::acc_declare_link)
1089 return emitError(
1090 "data clause associated with delete operation must match its intent"
1091 " or specify original clause this operation was decomposed from");
1092 if (!getAccVar())
1093 return emitError("must have device pointer");
1094 // This op is the exit part of copyin and create - thus allow all modifiers
1095 // allowed on either case.
1096 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1097 acc::DataClauseModifier::readonly |
1098 acc::DataClauseModifier::always |
1099 acc::DataClauseModifier::capture)))
1100 return failure();
1101 return success();
1102}
1103
1104//===----------------------------------------------------------------------===//
1105// DetachOp
1106//===----------------------------------------------------------------------===//
1107LogicalResult acc::DetachOp::verify() {
1108 // Test for all clauses this operation can be decomposed from:
1109 if (getDataClause() != acc::DataClause::acc_detach &&
1110 getDataClause() != acc::DataClause::acc_attach)
1111 return emitError(
1112 "data clause associated with detach operation must match its intent"
1113 " or specify original clause this operation was decomposed from");
1114 if (!getAccVar())
1115 return emitError("must have device pointer");
1116 if (failed(checkNoModifier(*this)))
1117 return failure();
1118 return success();
1119}
1120
1121//===----------------------------------------------------------------------===//
1122// HostOp
1123//===----------------------------------------------------------------------===//
1124LogicalResult acc::UpdateHostOp::verify() {
1125 // Test for all clauses this operation can be decomposed from:
1126 if (getDataClause() != acc::DataClause::acc_update_host &&
1127 getDataClause() != acc::DataClause::acc_update_self)
1128 return emitError(
1129 "data clause associated with host operation must match its intent"
1130 " or specify original clause this operation was decomposed from");
1131 if (!getVar() || !getAccVar())
1132 return emitError("must have both host and device pointers");
1133 if (failed(checkVarAndVarType(*this)))
1134 return failure();
1135 if (failed(checkVarAndAccVar(*this)))
1136 return failure();
1137 if (failed(checkNoModifier(*this)))
1138 return failure();
1139 return success();
1140}
1141
1142//===----------------------------------------------------------------------===//
1143// DeviceOp
1144//===----------------------------------------------------------------------===//
1145LogicalResult acc::UpdateDeviceOp::verify() {
1146 // Test for all clauses this operation can be decomposed from:
1147 if (getDataClause() != acc::DataClause::acc_update_device)
1148 return emitError(
1149 "data clause associated with device operation must match its intent"
1150 " or specify original clause this operation was decomposed from");
1151 if (failed(checkVarAndVarType(*this)))
1152 return failure();
1153 if (failed(checkVarAndAccVar(*this)))
1154 return failure();
1155 if (failed(checkNoModifier(*this)))
1156 return failure();
1157 return success();
1158}
1159
1160//===----------------------------------------------------------------------===//
1161// UseDeviceOp
1162//===----------------------------------------------------------------------===//
1163LogicalResult acc::UseDeviceOp::verify() {
1164 // Test for all clauses this operation can be decomposed from:
1165 if (getDataClause() != acc::DataClause::acc_use_device)
1166 return emitError(
1167 "data clause associated with use_device operation must match its intent"
1168 " or specify original clause this operation was decomposed from");
1169 if (failed(checkVarAndVarType(*this)))
1170 return failure();
1171 if (failed(checkVarAndAccVar(*this)))
1172 return failure();
1173 if (failed(checkNoModifier(*this)))
1174 return failure();
1175 return success();
1176}
1177
1178//===----------------------------------------------------------------------===//
1179// CacheOp
1180//===----------------------------------------------------------------------===//
1181LogicalResult acc::CacheOp::verify() {
1182 // Test for all clauses this operation can be decomposed from:
1183 if (getDataClause() != acc::DataClause::acc_cache &&
1184 getDataClause() != acc::DataClause::acc_cache_readonly)
1185 return emitError(
1186 "data clause associated with cache operation must match its intent"
1187 " or specify original clause this operation was decomposed from");
1188 if (failed(checkVarAndVarType(*this)))
1189 return failure();
1190 if (failed(checkVarAndAccVar(*this)))
1191 return failure();
1192 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1193 return failure();
1194 return success();
1195}
1196
1197bool acc::CacheOp::isCacheReadonly() {
1198 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1199 acc::bitEnumContainsAny(getModifiers(),
1200 acc::DataClauseModifier::readonly);
1201}
1202
1203//===----------------------------------------------------------------------===//
1204// Data entry/exit operations - getEffects implementations
1205//===----------------------------------------------------------------------===//
1206
1207// This function returns true iff the given operation is enclosed
1208// in any ACC_COMPUTE_CONSTRUCT_OPS operation.
1209// It is quite alike acc::getEnclosingComputeOp() utility,
1210// but we cannot use it here.
1212 mlir::Operation *parentOp = op->getParentOp();
1213 while (parentOp) {
1214 if (mlir::isa<ACC_COMPUTE_CONSTRUCT_OPS>(parentOp))
1215 return true;
1216 parentOp = parentOp->getParentOp();
1217 }
1218 return false;
1219}
1220
1221/// Helper to add an effect on an operand, referenced by its mutable range.
1222template <typename EffectTy>
1225 &effects,
1226 MutableOperandRange operand) {
1227 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1228 effects.emplace_back(EffectTy::get(), &operand[i]);
1229}
1230
1231/// Helper to add an effect on a result value.
1232template <typename EffectTy>
1235 &effects,
1236 Value result) {
1237 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1238}
1239
1240// PrivateOp: accVar result write.
1241void acc::PrivateOp::getEffects(
1243 &effects) {
1244 // If acc.private is enclosed into a compute operation,
1245 // then it denotes the device side privatization, hence
1246 // it does not access the CurrentDeviceIdResource.
1247 if (!isEnclosedIntoComputeOp(getOperation()))
1248 effects.emplace_back(MemoryEffects::Read::get(),
1250 // TODO: should this be MemoryEffects::Allocate?
1252}
1253
1254// FirstprivateOp: var read, accVar result write.
1255void acc::FirstprivateOp::getEffects(
1257 &effects) {
1258 // If acc.firstprivate is enclosed into a compute operation,
1259 // then it denotes the device side privatization, hence
1260 // it does not access the CurrentDeviceIdResource.
1261 if (!isEnclosedIntoComputeOp(getOperation()))
1262 effects.emplace_back(MemoryEffects::Read::get(),
1264 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1266}
1267
1268// ReductionOp: var read, accVar result write.
1269void acc::ReductionOp::getEffects(
1271 &effects) {
1272 // If acc.reduction is enclosed into a compute operation,
1273 // then it denotes the device side reduction, hence
1274 // it does not access the CurrentDeviceIdResource.
1275 if (!isEnclosedIntoComputeOp(getOperation()))
1276 effects.emplace_back(MemoryEffects::Read::get(),
1278 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1280}
1281
1282// DevicePtrOp: RuntimeCounters read.
1283void acc::DevicePtrOp::getEffects(
1285 &effects) {
1286 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1287 effects.emplace_back(MemoryEffects::Read::get(),
1289}
1290
1291// PresentOp: RuntimeCounters read+write.
1292void acc::PresentOp::getEffects(
1294 &effects) {
1295 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1296 effects.emplace_back(MemoryEffects::Write::get(),
1298 effects.emplace_back(MemoryEffects::Read::get(),
1300}
1301
1302// CopyinOp: RuntimeCounters read+write, var read, accVar result write.
1303void acc::CopyinOp::getEffects(
1305 &effects) {
1306 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1307 effects.emplace_back(MemoryEffects::Write::get(),
1309 effects.emplace_back(MemoryEffects::Read::get(),
1311 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1313}
1314
1315// CreateOp: RuntimeCounters read+write, accVar result write.
1316void acc::CreateOp::getEffects(
1318 &effects) {
1319 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1320 effects.emplace_back(MemoryEffects::Write::get(),
1322 effects.emplace_back(MemoryEffects::Read::get(),
1324 // TODO: should this be MemoryEffects::Allocate?
1326}
1327
1328// NoCreateOp: RuntimeCounters read+write.
1329void acc::NoCreateOp::getEffects(
1331 &effects) {
1332 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1333 effects.emplace_back(MemoryEffects::Write::get(),
1335 effects.emplace_back(MemoryEffects::Read::get(),
1337}
1338
1339// AttachOp: RuntimeCounters read+write, var read.
1340void acc::AttachOp::getEffects(
1342 &effects) {
1343 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1344 effects.emplace_back(MemoryEffects::Write::get(),
1346 effects.emplace_back(MemoryEffects::Read::get(),
1348 // TODO: should we also add MemoryEffects::Write?
1349 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1350}
1351
1352// GetDevicePtrOp: RuntimeCounters read.
1353void acc::GetDevicePtrOp::getEffects(
1355 &effects) {
1356 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1357 effects.emplace_back(MemoryEffects::Read::get(),
1359}
1360
1361// UpdateDeviceOp: var read, accVar result write.
1362void acc::UpdateDeviceOp::getEffects(
1364 &effects) {
1365 effects.emplace_back(MemoryEffects::Read::get(),
1367 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1369}
1370
1371// UseDeviceOp: RuntimeCounters read.
1372void acc::UseDeviceOp::getEffects(
1374 &effects) {
1375 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1376 effects.emplace_back(MemoryEffects::Read::get(),
1378}
1379
1380// DeclareDeviceResidentOp: RuntimeCounters write, var read.
1381void acc::DeclareDeviceResidentOp::getEffects(
1383 &effects) {
1384 effects.emplace_back(MemoryEffects::Write::get(),
1386 effects.emplace_back(MemoryEffects::Read::get(),
1388 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1389}
1390
1391// DeclareLinkOp: RuntimeCounters write, var read.
1392void acc::DeclareLinkOp::getEffects(
1394 &effects) {
1395 effects.emplace_back(MemoryEffects::Write::get(),
1397 effects.emplace_back(MemoryEffects::Read::get(),
1399 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1400}
1401
1402// CacheOp: NoMemoryEffect
1403void acc::CacheOp::getEffects(
1405 &effects) {}
1406
1407// CopyoutOp: RuntimeCounters read+write, accVar read, var write.
1408void acc::CopyoutOp::getEffects(
1410 &effects) {
1411 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1412 effects.emplace_back(MemoryEffects::Write::get(),
1414 effects.emplace_back(MemoryEffects::Read::get(),
1416 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1417 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1418}
1419
1420// DeleteOp: RuntimeCounters read+write, accVar read.
1421void acc::DeleteOp::getEffects(
1423 &effects) {
1424 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1425 effects.emplace_back(MemoryEffects::Write::get(),
1427 effects.emplace_back(MemoryEffects::Read::get(),
1429 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1430}
1431
1432// DetachOp: RuntimeCounters read+write, accVar read.
1433void acc::DetachOp::getEffects(
1435 &effects) {
1436 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1437 effects.emplace_back(MemoryEffects::Write::get(),
1439 effects.emplace_back(MemoryEffects::Read::get(),
1441 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1442}
1443
1444// UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
1445void acc::UpdateHostOp::getEffects(
1447 &effects) {
1448 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1449 effects.emplace_back(MemoryEffects::Write::get(),
1451 effects.emplace_back(MemoryEffects::Read::get(),
1453 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1454 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1455}
1456
1457template <typename StructureOp>
1458static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1459 unsigned nRegions = 1) {
1460
1462 for (unsigned i = 0; i < nRegions; ++i)
1463 regions.push_back(state.addRegion());
1464
1465 for (Region *region : regions)
1466 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1467 return failure();
1468
1469 return success();
1470}
1471
1473 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1474}
1475
1476namespace {
1477/// Pattern to remove operation without region that have constant false `ifCond`
1478/// and remove the condition from the operation if the `ifCond` is a true
1479/// constant.
1480template <typename OpTy>
1481struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1482 using OpRewritePattern<OpTy>::OpRewritePattern;
1483
1484 LogicalResult matchAndRewrite(OpTy op,
1485 PatternRewriter &rewriter) const override {
1486 // Early return if there is no condition.
1487 Value ifCond = op.getIfCond();
1488 if (!ifCond)
1489 return failure();
1490
1491 IntegerAttr constAttr;
1492 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1493 return failure();
1494 if (constAttr.getInt())
1495 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1496 else
1497 rewriter.eraseOp(op);
1498
1499 return success();
1500 }
1501};
1502
1503/// Replaces the given op with the contents of the given single-block region,
1504/// using the operands of the block terminator to replace operation results.
1505static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1506 Region &region, ValueRange blockArgs = {}) {
1507 assert(region.hasOneBlock() && "expected single-block region");
1508 Block *block = &region.front();
1509 Operation *terminator = block->getTerminator();
1510 ValueRange results = terminator->getOperands();
1511 rewriter.inlineBlockBefore(block, op, blockArgs);
1512 rewriter.replaceOp(op, results);
1513 rewriter.eraseOp(terminator);
1514}
1515
1516/// Pattern to remove operation with region that have constant false `ifCond`
1517/// and remove the condition from the operation if the `ifCond` is constant
1518/// true.
1519template <typename OpTy>
1520struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1521 using OpRewritePattern<OpTy>::OpRewritePattern;
1522
1523 LogicalResult matchAndRewrite(OpTy op,
1524 PatternRewriter &rewriter) const override {
1525 // Early return if there is no condition.
1526 Value ifCond = op.getIfCond();
1527 if (!ifCond)
1528 return failure();
1529
1530 IntegerAttr constAttr;
1531 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1532 return failure();
1533 if (constAttr.getInt())
1534 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1535 else
1536 replaceOpWithRegion(rewriter, op, op.getRegion());
1537
1538 return success();
1539 }
1540};
1541
1542//===----------------------------------------------------------------------===//
1543// Recipe Region Helpers
1544//===----------------------------------------------------------------------===//
1545
1546/// Create and populate an init region for privatization recipes.
1547/// Returns success if the region is populated, failure otherwise.
1548/// Sets needsFree to indicate if the allocated memory requires deallocation.
1549static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1550 Region &initRegion, Type varType,
1551 StringRef varName, ValueRange bounds,
1552 bool &needsFree) {
1553 // Create init block with arguments: original value + bounds
1554 SmallVector<Type> argTypes{varType};
1555 SmallVector<Location> argLocs{loc};
1556 for (Value bound : bounds) {
1557 argTypes.push_back(bound.getType());
1558 argLocs.push_back(loc);
1559 }
1560
1561 Block *initBlock = builder.createBlock(&initRegion);
1562 initBlock->addArguments(argTypes, argLocs);
1563 builder.setInsertionPointToStart(initBlock);
1564
1565 Value privatizedValue;
1566
1567 // Get the block argument that represents the original variable
1568 Value blockArgVar = initBlock->getArgument(0);
1569
1570 // Generate init region body based on variable type
1571 if (isa<MappableType>(varType)) {
1572 auto mappableTy = cast<MappableType>(varType);
1573 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1574 privatizedValue = mappableTy.generatePrivateInit(
1575 builder, loc, typedVar, varName, bounds, {}, needsFree);
1576 if (!privatizedValue)
1577 return failure();
1578 } else {
1579 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1580 auto pointerLikeTy = cast<PointerLikeType>(varType);
1581 // Use PointerLikeType's allocation API with the block argument
1582 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1583 blockArgVar, needsFree);
1584 if (!privatizedValue)
1585 return failure();
1586 }
1587
1588 // Add yield operation to init block
1589 acc::YieldOp::create(builder, loc, privatizedValue);
1590
1591 return success();
1592}
1593
1594/// Create and populate a copy region for firstprivate recipes.
1595/// Returns success if the region is populated, failure otherwise.
1596/// TODO: Handle MappableType - it does not yet have a copy API.
1597static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1598 Region &copyRegion, Type varType,
1599 ValueRange bounds) {
1600 // Create copy block with arguments: original value + privatized value +
1601 // bounds
1602 SmallVector<Type> copyArgTypes{varType, varType};
1603 SmallVector<Location> copyArgLocs{loc, loc};
1604 for (Value bound : bounds) {
1605 copyArgTypes.push_back(bound.getType());
1606 copyArgLocs.push_back(loc);
1607 }
1608
1609 Block *copyBlock = builder.createBlock(&copyRegion);
1610 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1611 builder.setInsertionPointToStart(copyBlock);
1612
1613 bool isMappable = isa<MappableType>(varType);
1614 bool isPointerLike = isa<PointerLikeType>(varType);
1615 // TODO: Handle MappableType - it does not yet have a copy API.
1616 // Otherwise, for now just fallback to pointer-like behavior.
1617 if (isMappable && !isPointerLike)
1618 return failure();
1619
1620 // Generate copy region body based on variable type
1621 if (isPointerLike) {
1622 auto pointerLikeTy = cast<PointerLikeType>(varType);
1623 Value originalArg = copyBlock->getArgument(0);
1624 Value privatizedArg = copyBlock->getArgument(1);
1625
1626 // Generate copy operation using PointerLikeType interface
1627 if (!pointerLikeTy.genCopy(
1628 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1629 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1630 return failure();
1631 }
1632
1633 // Add terminator to copy block
1634 acc::TerminatorOp::create(builder, loc);
1635
1636 return success();
1637}
1638
1639/// Create and populate a destroy region for privatization recipes.
1640/// Returns success if the region is populated, failure otherwise.
1641static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1642 Region &destroyRegion, Type varType,
1643 Value allocRes, ValueRange bounds) {
1644 // Create destroy block with arguments: original value + privatized value +
1645 // bounds
1646 SmallVector<Type> destroyArgTypes{varType, varType};
1647 SmallVector<Location> destroyArgLocs{loc, loc};
1648 for (Value bound : bounds) {
1649 destroyArgTypes.push_back(bound.getType());
1650 destroyArgLocs.push_back(loc);
1651 }
1652
1653 Block *destroyBlock = builder.createBlock(&destroyRegion);
1654 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1655 builder.setInsertionPointToStart(destroyBlock);
1656
1657 auto varToFree =
1658 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1659 if (isa<MappableType>(varType)) {
1660 auto mappableTy = cast<MappableType>(varType);
1661 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1662 return failure();
1663 } else {
1664 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1665 auto pointerLikeTy = cast<PointerLikeType>(varType);
1666 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1667 return failure();
1668 }
1669
1670 acc::TerminatorOp::create(builder, loc);
1671 return success();
1672}
1673
1674} // namespace
1675
1676//===----------------------------------------------------------------------===//
1677// PrivateRecipeOp
1678//===----------------------------------------------------------------------===//
1679
1681 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1682 Type type, bool verifyYield, bool optional = false) {
1683 if (optional && region.empty())
1684 return success();
1685
1686 if (region.empty())
1687 return op->emitOpError() << "expects non-empty " << regionName << " region";
1688 Block &firstBlock = region.front();
1689 if (firstBlock.getNumArguments() < 1 ||
1690 firstBlock.getArgument(0).getType() != type)
1691 return op->emitOpError() << "expects " << regionName
1692 << " region first "
1693 "argument of the "
1694 << regionType << " type";
1695
1696 if (verifyYield) {
1697 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1698 if (yieldOp.getOperands().size() != 1 ||
1699 yieldOp.getOperands().getTypes()[0] != type)
1700 return op->emitOpError() << "expects " << regionName
1701 << " region to "
1702 "yield a value of the "
1703 << regionType << " type";
1704 }
1705 }
1706 return success();
1707}
1708
1709LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1710 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1711 "privatization", "init", getType(),
1712 /*verifyYield=*/false)))
1713 return failure();
1715 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1716 /*verifyYield=*/false, /*optional=*/true)))
1717 return failure();
1718 return success();
1719}
1720
1721std::optional<PrivateRecipeOp>
1722PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1723 StringRef recipeName, Type varType,
1724 StringRef varName, ValueRange bounds) {
1725 // First, validate that we can handle this variable type
1726 bool isMappable = isa<MappableType>(varType);
1727 bool isPointerLike = isa<PointerLikeType>(varType);
1728
1729 // Unsupported type
1730 if (!isMappable && !isPointerLike)
1731 return std::nullopt;
1732
1733 OpBuilder::InsertionGuard guard(builder);
1734
1735 // Create the recipe operation first so regions have proper parent context
1736 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1737
1738 // Populate the init region
1739 bool needsFree = false;
1740 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1741 varName, bounds, needsFree))) {
1742 recipe.erase();
1743 return std::nullopt;
1744 }
1745
1746 // Only create destroy region if the allocation needs deallocation
1747 if (needsFree) {
1748 // Extract the allocated value from the init block's yield operation
1749 auto yieldOp =
1750 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1751 Value allocRes = yieldOp.getOperand(0);
1752
1753 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1754 varType, allocRes, bounds))) {
1755 recipe.erase();
1756 return std::nullopt;
1757 }
1758 }
1759
1760 return recipe;
1761}
1762
1763std::optional<PrivateRecipeOp>
1764PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1765 StringRef recipeName,
1766 FirstprivateRecipeOp firstprivRecipe) {
1767 // Create the private.recipe op with the same type as the firstprivate.recipe.
1768 OpBuilder::InsertionGuard guard(builder);
1769 auto varType = firstprivRecipe.getType();
1770 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1771
1772 // Clone the init region
1773 IRMapping mapping;
1774 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1775
1776 // Clone destroy region if the firstprivate.recipe has one.
1777 if (!firstprivRecipe.getDestroyRegion().empty()) {
1778 IRMapping mapping;
1779 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1780 mapping);
1781 }
1782 return recipe;
1783}
1784
1785//===----------------------------------------------------------------------===//
1786// FirstprivateRecipeOp
1787//===----------------------------------------------------------------------===//
1788
1789LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1790 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1791 "privatization", "init", getType(),
1792 /*verifyYield=*/false)))
1793 return failure();
1794
1795 if (getCopyRegion().empty())
1796 return emitOpError() << "expects non-empty copy region";
1797
1798 Block &firstBlock = getCopyRegion().front();
1799 if (firstBlock.getNumArguments() < 2 ||
1800 firstBlock.getArgument(0).getType() != getType())
1801 return emitOpError() << "expects copy region with two arguments of the "
1802 "privatization type";
1803
1804 if (getDestroyRegion().empty())
1805 return success();
1806
1807 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1808 "privatization", "destroy",
1809 getType(), /*verifyYield=*/false)))
1810 return failure();
1811
1812 return success();
1813}
1814
1815std::optional<FirstprivateRecipeOp>
1816FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1817 StringRef recipeName, Type varType,
1818 StringRef varName, ValueRange bounds) {
1819 // First, validate that we can handle this variable type
1820 bool isMappable = isa<MappableType>(varType);
1821 bool isPointerLike = isa<PointerLikeType>(varType);
1822
1823 // Unsupported type
1824 if (!isMappable && !isPointerLike)
1825 return std::nullopt;
1826
1827 OpBuilder::InsertionGuard guard(builder);
1828
1829 // Create the recipe operation first so regions have proper parent context
1830 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1831
1832 // Populate the init region
1833 bool needsFree = false;
1834 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1835 varName, bounds, needsFree))) {
1836 recipe.erase();
1837 return std::nullopt;
1838 }
1839
1840 // Populate the copy region
1841 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1842 bounds))) {
1843 recipe.erase();
1844 return std::nullopt;
1845 }
1846
1847 // Only create destroy region if the allocation needs deallocation
1848 if (needsFree) {
1849 // Extract the allocated value from the init block's yield operation
1850 auto yieldOp =
1851 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1852 Value allocRes = yieldOp.getOperand(0);
1853
1854 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1855 varType, allocRes, bounds))) {
1856 recipe.erase();
1857 return std::nullopt;
1858 }
1859 }
1860
1861 return recipe;
1862}
1863
1864//===----------------------------------------------------------------------===//
1865// ReductionRecipeOp
1866//===----------------------------------------------------------------------===//
1867
1868LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1869 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1870 "init", getType(),
1871 /*verifyYield=*/false)))
1872 return failure();
1873
1874 if (getCombinerRegion().empty())
1875 return emitOpError() << "expects non-empty combiner region";
1876
1877 Block &reductionBlock = getCombinerRegion().front();
1878 if (reductionBlock.getNumArguments() < 2 ||
1879 reductionBlock.getArgument(0).getType() != getType() ||
1880 reductionBlock.getArgument(1).getType() != getType())
1881 return emitOpError() << "expects combiner region with the first two "
1882 << "arguments of the reduction type";
1883
1884 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1885 if (yieldOp.getOperands().size() != 1 ||
1886 yieldOp.getOperands().getTypes()[0] != getType())
1887 return emitOpError() << "expects combiner region to yield a value "
1888 "of the reduction type";
1889 }
1890
1891 return success();
1892}
1893
1894//===----------------------------------------------------------------------===//
1895// ParallelOp
1896//===----------------------------------------------------------------------===//
1897
1898/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1899template <typename Op>
1900static LogicalResult checkDataOperands(Op op,
1901 const mlir::ValueRange &operands) {
1902 for (mlir::Value operand : operands)
1903 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1904 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1905 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1906 operand.getDefiningOp()))
1907 return op.emitError(
1908 "expect data entry/exit operation or acc.getdeviceptr "
1909 "as defining op");
1910 return success();
1911}
1912
1913template <typename OpT, typename RecipeOpT>
1914static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1915 const mlir::ValueRange &operands,
1916 llvm::StringRef operandName) {
1918 for (mlir::Value operand : operands) {
1919 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1920 return accConstructOp->emitOpError()
1921 << "expected " << operandName << " as defining op";
1922 if (!set.insert(operand).second)
1923 return accConstructOp->emitOpError()
1924 << operandName << " operand appears more than once";
1925 }
1926 return success();
1927}
1928
1929unsigned ParallelOp::getNumDataOperands() {
1930 return getReductionOperands().size() + getPrivateOperands().size() +
1931 getFirstprivateOperands().size() + getDataClauseOperands().size();
1932}
1933
1934Value ParallelOp::getDataOperand(unsigned i) {
1935 unsigned numOptional = getAsyncOperands().size();
1936 numOptional += getNumGangs().size();
1937 numOptional += getNumWorkers().size();
1938 numOptional += getVectorLength().size();
1939 numOptional += getIfCond() ? 1 : 0;
1940 numOptional += getSelfCond() ? 1 : 0;
1941 return getOperand(getWaitOperands().size() + numOptional + i);
1942}
1943
1944template <typename Op>
1945static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1946 ArrayAttr deviceTypes,
1947 llvm::StringRef keyword) {
1948 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1949 return op.emitOpError() << keyword << " operands count must match "
1950 << keyword << " device_type count";
1951 return success();
1952}
1953
1954template <typename Op>
1956 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1957 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1958 std::size_t numOperandsInSegments = 0;
1959 std::size_t nbOfSegments = 0;
1960
1961 if (segments) {
1962 for (auto segCount : segments.asArrayRef()) {
1963 if (maxInSegment != 0 && segCount > maxInSegment)
1964 return op.emitOpError() << keyword << " expects a maximum of "
1965 << maxInSegment << " values per segment";
1966 numOperandsInSegments += segCount;
1967 ++nbOfSegments;
1968 }
1969 }
1970
1971 if ((numOperandsInSegments != operands.size()) ||
1972 (!deviceTypes && !operands.empty()))
1973 return op.emitOpError()
1974 << keyword << " operand count does not match count in segments";
1975 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1976 return op.emitOpError()
1977 << keyword << " segment count does not match device_type count";
1978 return success();
1979}
1980
1981LogicalResult acc::ParallelOp::verify() {
1982 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1983 mlir::acc::PrivateRecipeOp>(
1984 *this, getPrivateOperands(), "private")))
1985 return failure();
1986 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
1987 mlir::acc::FirstprivateRecipeOp>(
1988 *this, getFirstprivateOperands(), "firstprivate")))
1989 return failure();
1990 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
1991 mlir::acc::ReductionRecipeOp>(
1992 *this, getReductionOperands(), "reduction")))
1993 return failure();
1994
1996 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1997 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1998 return failure();
1999
2001 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2002 getWaitOperandsDeviceTypeAttr(), "wait")))
2003 return failure();
2004
2005 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2006 getNumWorkersDeviceTypeAttr(),
2007 "num_workers")))
2008 return failure();
2009
2010 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2011 getVectorLengthDeviceTypeAttr(),
2012 "vector_length")))
2013 return failure();
2014
2016 getAsyncOperandsDeviceTypeAttr(),
2017 "async")))
2018 return failure();
2019
2021 return failure();
2022
2023 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
2024}
2025
2026static mlir::Value
2027getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
2029 mlir::acc::DeviceType deviceType) {
2030 if (!arrayAttr)
2031 return {};
2032 if (auto pos = findSegment(*arrayAttr, deviceType))
2033 return range[*pos];
2034 return {};
2035}
2036
2037bool acc::ParallelOp::hasAsyncOnly() {
2038 return hasAsyncOnly(mlir::acc::DeviceType::None);
2039}
2040
2041bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2042 return hasDeviceType(getAsyncOnly(), deviceType);
2043}
2044
2045mlir::Value acc::ParallelOp::getAsyncValue() {
2046 return getAsyncValue(mlir::acc::DeviceType::None);
2047}
2048
2049mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2051 getAsyncOperands(), deviceType);
2052}
2053
2054mlir::Value acc::ParallelOp::getNumWorkersValue() {
2055 return getNumWorkersValue(mlir::acc::DeviceType::None);
2056}
2057
2059acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2060 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2061 deviceType);
2062}
2063
2064mlir::Value acc::ParallelOp::getVectorLengthValue() {
2065 return getVectorLengthValue(mlir::acc::DeviceType::None);
2066}
2067
2069acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2070 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2071 getVectorLength(), deviceType);
2072}
2073
2074mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2075 return getNumGangsValues(mlir::acc::DeviceType::None);
2076}
2077
2079ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2080 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2081 getNumGangsSegments(), deviceType);
2082}
2083
2084bool acc::ParallelOp::hasWaitOnly() {
2085 return hasWaitOnly(mlir::acc::DeviceType::None);
2086}
2087
2088bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2089 return hasDeviceType(getWaitOnly(), deviceType);
2090}
2091
2092mlir::Operation::operand_range ParallelOp::getWaitValues() {
2093 return getWaitValues(mlir::acc::DeviceType::None);
2094}
2095
2097ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2099 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2100 getHasWaitDevnum(), deviceType);
2101}
2102
2103mlir::Value ParallelOp::getWaitDevnum() {
2104 return getWaitDevnum(mlir::acc::DeviceType::None);
2105}
2106
2107mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2108 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2109 getWaitOperandsSegments(), getHasWaitDevnum(),
2110 deviceType);
2111}
2112
2113void ParallelOp::build(mlir::OpBuilder &odsBuilder,
2114 mlir::OperationState &odsState,
2115 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
2116 mlir::ValueRange vectorLength,
2117 mlir::ValueRange asyncOperands,
2118 mlir::ValueRange waitOperands, mlir::Value ifCond,
2119 mlir::Value selfCond, mlir::ValueRange reductionOperands,
2120 mlir::ValueRange gangPrivateOperands,
2121 mlir::ValueRange gangFirstPrivateOperands,
2122 mlir::ValueRange dataClauseOperands) {
2123 ParallelOp::build(
2124 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
2125 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
2126 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
2127 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
2128 /*numGangsDeviceType=*/nullptr, numWorkers,
2129 /*numWorkersDeviceType=*/nullptr, vectorLength,
2130 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
2131 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
2132 gangFirstPrivateOperands, dataClauseOperands,
2133 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
2134}
2135
2136void acc::ParallelOp::addNumWorkersOperand(
2137 MLIRContext *context, mlir::Value newValue,
2138 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2139 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2140 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2141 getNumWorkersMutable()));
2142}
2143void acc::ParallelOp::addVectorLengthOperand(
2144 MLIRContext *context, mlir::Value newValue,
2145 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2146 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2147 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2148 getVectorLengthMutable()));
2149}
2150
2151void acc::ParallelOp::addAsyncOnly(
2152 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2153 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2154 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2155}
2156
2157void acc::ParallelOp::addAsyncOperand(
2158 MLIRContext *context, mlir::Value newValue,
2159 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2160 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2161 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2162 getAsyncOperandsMutable()));
2163}
2164
2165void acc::ParallelOp::addNumGangsOperands(
2166 MLIRContext *context, mlir::ValueRange newValues,
2167 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2169 if (getNumGangsSegments())
2170 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2171
2172 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2173 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2174 getNumGangsMutable(), segments));
2175
2176 setNumGangsSegments(segments);
2177}
2178void acc::ParallelOp::addWaitOnly(
2179 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2180 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2181 effectiveDeviceTypes));
2182}
2183void acc::ParallelOp::addWaitOperands(
2184 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2185 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2186
2188 if (getWaitOperandsSegments())
2189 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2190
2191 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2192 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2193 getWaitOperandsMutable(), segments));
2194 setWaitOperandsSegments(segments);
2195
2197 if (getHasWaitDevnumAttr())
2198 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2199 hasDevnums.insert(
2200 hasDevnums.end(),
2201 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2202 mlir::BoolAttr::get(context, hasDevnum));
2203 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2204}
2205
2206void acc::ParallelOp::addPrivatization(MLIRContext *context,
2207 mlir::acc::PrivateOp op,
2208 mlir::acc::PrivateRecipeOp recipe) {
2209 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2210 getPrivateOperandsMutable().append(op.getResult());
2211}
2212
2213void acc::ParallelOp::addFirstPrivatization(
2214 MLIRContext *context, mlir::acc::FirstprivateOp op,
2215 mlir::acc::FirstprivateRecipeOp recipe) {
2216 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2217 getFirstprivateOperandsMutable().append(op.getResult());
2218}
2219
2220void acc::ParallelOp::addReduction(MLIRContext *context,
2221 mlir::acc::ReductionOp op,
2222 mlir::acc::ReductionRecipeOp recipe) {
2223 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2224 getReductionOperandsMutable().append(op.getResult());
2225}
2226
2227static ParseResult parseNumGangs(
2228 mlir::OpAsmParser &parser,
2230 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2231 mlir::DenseI32ArrayAttr &segments) {
2234
2235 do {
2236 if (failed(parser.parseLBrace()))
2237 return failure();
2238
2239 int32_t crtOperandsSize = operands.size();
2240 if (failed(parser.parseCommaSeparatedList(
2242 if (parser.parseOperand(operands.emplace_back()) ||
2243 parser.parseColonType(types.emplace_back()))
2244 return failure();
2245 return success();
2246 })))
2247 return failure();
2248 seg.push_back(operands.size() - crtOperandsSize);
2249
2250 if (failed(parser.parseRBrace()))
2251 return failure();
2252
2253 if (succeeded(parser.parseOptionalLSquare())) {
2254 if (parser.parseAttribute(attributes.emplace_back()) ||
2255 parser.parseRSquare())
2256 return failure();
2257 } else {
2258 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2259 parser.getContext(), mlir::acc::DeviceType::None));
2260 }
2261 } while (succeeded(parser.parseOptionalComma()));
2262
2263 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2264 attributes.end());
2265 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2266 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2267
2268 return success();
2269}
2270
2272 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2273 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2274 p << " [" << attr << "]";
2275}
2276
2278 mlir::OperandRange operands, mlir::TypeRange types,
2279 std::optional<mlir::ArrayAttr> deviceTypes,
2280 std::optional<mlir::DenseI32ArrayAttr> segments) {
2281 unsigned opIdx = 0;
2282 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2283 p << "{";
2284 llvm::interleaveComma(
2285 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2286 p << operands[opIdx] << " : " << operands[opIdx].getType();
2287 ++opIdx;
2288 });
2289 p << "}";
2290 printSingleDeviceType(p, it.value());
2291 });
2292}
2293
2295 mlir::OpAsmParser &parser,
2297 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2298 mlir::DenseI32ArrayAttr &segments) {
2301
2302 do {
2303 if (failed(parser.parseLBrace()))
2304 return failure();
2305
2306 int32_t crtOperandsSize = operands.size();
2307
2308 if (failed(parser.parseCommaSeparatedList(
2310 if (parser.parseOperand(operands.emplace_back()) ||
2311 parser.parseColonType(types.emplace_back()))
2312 return failure();
2313 return success();
2314 })))
2315 return failure();
2316
2317 seg.push_back(operands.size() - crtOperandsSize);
2318
2319 if (failed(parser.parseRBrace()))
2320 return failure();
2321
2322 if (succeeded(parser.parseOptionalLSquare())) {
2323 if (parser.parseAttribute(attributes.emplace_back()) ||
2324 parser.parseRSquare())
2325 return failure();
2326 } else {
2327 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2328 parser.getContext(), mlir::acc::DeviceType::None));
2329 }
2330 } while (succeeded(parser.parseOptionalComma()));
2331
2332 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2333 attributes.end());
2334 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2335 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2336
2337 return success();
2338}
2339
2342 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2343 std::optional<mlir::DenseI32ArrayAttr> segments) {
2344 unsigned opIdx = 0;
2345 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2346 p << "{";
2347 llvm::interleaveComma(
2348 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2349 p << operands[opIdx] << " : " << operands[opIdx].getType();
2350 ++opIdx;
2351 });
2352 p << "}";
2353 printSingleDeviceType(p, it.value());
2354 });
2355}
2356
2357static ParseResult parseWaitClause(
2358 mlir::OpAsmParser &parser,
2360 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2361 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2362 mlir::ArrayAttr &keywordOnly) {
2363 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2365
2366 bool needCommaBeforeOperands = false;
2367
2368 // Keyword only
2369 if (failed(parser.parseOptionalLParen())) {
2370 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2371 parser.getContext(), mlir::acc::DeviceType::None));
2372 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2373 return success();
2374 }
2375
2376 // Parse keyword only attributes
2377 if (succeeded(parser.parseOptionalLSquare())) {
2378 if (failed(parser.parseCommaSeparatedList([&]() {
2379 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2380 return failure();
2381 return success();
2382 })))
2383 return failure();
2384 if (parser.parseRSquare())
2385 return failure();
2386 needCommaBeforeOperands = true;
2387 }
2388
2389 if (needCommaBeforeOperands && failed(parser.parseComma()))
2390 return failure();
2391
2392 do {
2393 if (failed(parser.parseLBrace()))
2394 return failure();
2395
2396 int32_t crtOperandsSize = operands.size();
2397
2398 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2399 if (failed(parser.parseColon()))
2400 return failure();
2401 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2402 } else {
2403 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2404 }
2405
2406 if (failed(parser.parseCommaSeparatedList(
2408 if (parser.parseOperand(operands.emplace_back()) ||
2409 parser.parseColonType(types.emplace_back()))
2410 return failure();
2411 return success();
2412 })))
2413 return failure();
2414
2415 seg.push_back(operands.size() - crtOperandsSize);
2416
2417 if (failed(parser.parseRBrace()))
2418 return failure();
2419
2420 if (succeeded(parser.parseOptionalLSquare())) {
2421 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2422 parser.parseRSquare())
2423 return failure();
2424 } else {
2425 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2426 parser.getContext(), mlir::acc::DeviceType::None));
2427 }
2428 } while (succeeded(parser.parseOptionalComma()));
2429
2430 if (failed(parser.parseRParen()))
2431 return failure();
2432
2433 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2434 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2435 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2436 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2437
2438 return success();
2439}
2440
2441static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2442 if (!hasDeviceTypeValues(attrs))
2443 return false;
2444 if (attrs->size() != 1)
2445 return false;
2446 if (auto deviceTypeAttr =
2447 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2448 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2449 return false;
2450}
2451
2453 mlir::OperandRange operands, mlir::TypeRange types,
2454 std::optional<mlir::ArrayAttr> deviceTypes,
2455 std::optional<mlir::DenseI32ArrayAttr> segments,
2456 std::optional<mlir::ArrayAttr> hasDevNum,
2457 std::optional<mlir::ArrayAttr> keywordOnly) {
2458
2459 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2460 return;
2461
2462 p << "(";
2463
2464 printDeviceTypes(p, keywordOnly);
2465 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2466 p << ", ";
2467
2468 if (hasDeviceTypeValues(deviceTypes)) {
2469 unsigned opIdx = 0;
2470 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2471 p << "{";
2472 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2473 if (boolAttr && boolAttr.getValue())
2474 p << "devnum: ";
2475 llvm::interleaveComma(
2476 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2477 p << operands[opIdx] << " : " << operands[opIdx].getType();
2478 ++opIdx;
2479 });
2480 p << "}";
2481 printSingleDeviceType(p, it.value());
2482 });
2483 }
2484
2485 p << ")";
2486}
2487
2488static ParseResult parseDeviceTypeOperands(
2489 mlir::OpAsmParser &parser,
2491 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2493 if (failed(parser.parseCommaSeparatedList([&]() {
2494 if (parser.parseOperand(operands.emplace_back()) ||
2495 parser.parseColonType(types.emplace_back()))
2496 return failure();
2497 if (succeeded(parser.parseOptionalLSquare())) {
2498 if (parser.parseAttribute(attributes.emplace_back()) ||
2499 parser.parseRSquare())
2500 return failure();
2501 } else {
2502 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2503 parser.getContext(), mlir::acc::DeviceType::None));
2504 }
2505 return success();
2506 })))
2507 return failure();
2508 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2509 attributes.end());
2510 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2511 return success();
2512}
2513
2514static void
2516 mlir::OperandRange operands, mlir::TypeRange types,
2517 std::optional<mlir::ArrayAttr> deviceTypes) {
2518 if (!hasDeviceTypeValues(deviceTypes))
2519 return;
2520 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2521 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2522 printSingleDeviceType(p, std::get<0>(it));
2523 });
2524}
2525
2527 mlir::OpAsmParser &parser,
2529 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2530 mlir::ArrayAttr &keywordOnlyDeviceType) {
2531
2532 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2533 bool needCommaBeforeOperands = false;
2534
2535 if (failed(parser.parseOptionalLParen())) {
2536 // Keyword only
2537 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2538 parser.getContext(), mlir::acc::DeviceType::None));
2539 keywordOnlyDeviceType =
2540 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2541 return success();
2542 }
2543
2544 // Parse keyword only attributes
2545 if (succeeded(parser.parseOptionalLSquare())) {
2546 // Parse keyword only attributes
2547 if (failed(parser.parseCommaSeparatedList([&]() {
2548 if (parser.parseAttribute(
2549 keywordOnlyDeviceTypeAttributes.emplace_back()))
2550 return failure();
2551 return success();
2552 })))
2553 return failure();
2554 if (parser.parseRSquare())
2555 return failure();
2556 needCommaBeforeOperands = true;
2557 }
2558
2559 if (needCommaBeforeOperands && failed(parser.parseComma()))
2560 return failure();
2561
2563 if (failed(parser.parseCommaSeparatedList([&]() {
2564 if (parser.parseOperand(operands.emplace_back()) ||
2565 parser.parseColonType(types.emplace_back()))
2566 return failure();
2567 if (succeeded(parser.parseOptionalLSquare())) {
2568 if (parser.parseAttribute(attributes.emplace_back()) ||
2569 parser.parseRSquare())
2570 return failure();
2571 } else {
2572 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2573 parser.getContext(), mlir::acc::DeviceType::None));
2574 }
2575 return success();
2576 })))
2577 return failure();
2578
2579 if (failed(parser.parseRParen()))
2580 return failure();
2581
2582 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2583 attributes.end());
2584 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2585 return success();
2586}
2587
2590 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2591 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2592
2593 if (operands.begin() == operands.end() &&
2594 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2595 return;
2596 }
2597
2598 p << "(";
2599 printDeviceTypes(p, keywordOnlyDeviceTypes);
2600 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2601 hasDeviceTypeValues(deviceTypes))
2602 p << ", ";
2603 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2604 p << ")";
2605}
2606
2608 mlir::OpAsmParser &parser,
2609 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2610 mlir::Type &operandType, mlir::UnitAttr &attr) {
2611 // Keyword only
2612 if (failed(parser.parseOptionalLParen())) {
2613 attr = mlir::UnitAttr::get(parser.getContext());
2614 return success();
2615 }
2616
2618 if (failed(parser.parseOperand(op)))
2619 return failure();
2620 operand = op;
2621 if (failed(parser.parseColon()))
2622 return failure();
2623 if (failed(parser.parseType(operandType)))
2624 return failure();
2625 if (failed(parser.parseRParen()))
2626 return failure();
2627
2628 return success();
2629}
2630
2632 mlir::Operation *op,
2633 std::optional<mlir::Value> operand,
2634 mlir::Type operandType,
2635 mlir::UnitAttr attr) {
2636 if (attr)
2637 return;
2638
2639 p << "(";
2640 p.printOperand(*operand);
2641 p << " : ";
2642 p.printType(operandType);
2643 p << ")";
2644}
2645
2647 mlir::OpAsmParser &parser,
2649 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2650 // Keyword only
2651 if (failed(parser.parseOptionalLParen())) {
2652 attr = mlir::UnitAttr::get(parser.getContext());
2653 return success();
2654 }
2655
2656 if (failed(parser.parseCommaSeparatedList([&]() {
2657 if (parser.parseOperand(operands.emplace_back()))
2658 return failure();
2659 return success();
2660 })))
2661 return failure();
2662 if (failed(parser.parseColon()))
2663 return failure();
2664 if (failed(parser.parseCommaSeparatedList([&]() {
2665 if (parser.parseType(types.emplace_back()))
2666 return failure();
2667 return success();
2668 })))
2669 return failure();
2670 if (failed(parser.parseRParen()))
2671 return failure();
2672
2673 return success();
2674}
2675
2677 mlir::Operation *op,
2678 mlir::OperandRange operands,
2679 mlir::TypeRange types,
2680 mlir::UnitAttr attr) {
2681 if (attr)
2682 return;
2683
2684 p << "(";
2685 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2686 p << " : ";
2687 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2688 p << ")";
2689}
2690
2691static ParseResult
2693 mlir::acc::CombinedConstructsTypeAttr &attr) {
2694 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2695 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2696 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2697 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2698 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2699 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2700 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2701 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2702 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2703 } else {
2704 parser.emitError(parser.getCurrentLocation(),
2705 "expected compute construct name");
2706 return failure();
2707 }
2708 return success();
2709}
2710
2711static void
2713 mlir::acc::CombinedConstructsTypeAttr attr) {
2714 if (attr) {
2715 switch (attr.getValue()) {
2716 case mlir::acc::CombinedConstructsType::KernelsLoop:
2717 p << "kernels";
2718 break;
2719 case mlir::acc::CombinedConstructsType::ParallelLoop:
2720 p << "parallel";
2721 break;
2722 case mlir::acc::CombinedConstructsType::SerialLoop:
2723 p << "serial";
2724 break;
2725 };
2726 }
2727}
2728
2729//===----------------------------------------------------------------------===//
2730// SerialOp
2731//===----------------------------------------------------------------------===//
2732
2733unsigned SerialOp::getNumDataOperands() {
2734 return getReductionOperands().size() + getPrivateOperands().size() +
2735 getFirstprivateOperands().size() + getDataClauseOperands().size();
2736}
2737
2738Value SerialOp::getDataOperand(unsigned i) {
2739 unsigned numOptional = getAsyncOperands().size();
2740 numOptional += getIfCond() ? 1 : 0;
2741 numOptional += getSelfCond() ? 1 : 0;
2742 return getOperand(getWaitOperands().size() + numOptional + i);
2743}
2744
2745bool acc::SerialOp::hasAsyncOnly() {
2746 return hasAsyncOnly(mlir::acc::DeviceType::None);
2747}
2748
2749bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2750 return hasDeviceType(getAsyncOnly(), deviceType);
2751}
2752
2753mlir::Value acc::SerialOp::getAsyncValue() {
2754 return getAsyncValue(mlir::acc::DeviceType::None);
2755}
2756
2757mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2759 getAsyncOperands(), deviceType);
2760}
2761
2762bool acc::SerialOp::hasWaitOnly() {
2763 return hasWaitOnly(mlir::acc::DeviceType::None);
2764}
2765
2766bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2767 return hasDeviceType(getWaitOnly(), deviceType);
2768}
2769
2770mlir::Operation::operand_range SerialOp::getWaitValues() {
2771 return getWaitValues(mlir::acc::DeviceType::None);
2772}
2773
2775SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2777 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2778 getHasWaitDevnum(), deviceType);
2779}
2780
2781mlir::Value SerialOp::getWaitDevnum() {
2782 return getWaitDevnum(mlir::acc::DeviceType::None);
2783}
2784
2785mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2786 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2787 getWaitOperandsSegments(), getHasWaitDevnum(),
2788 deviceType);
2789}
2790
2791LogicalResult acc::SerialOp::verify() {
2792 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2793 mlir::acc::PrivateRecipeOp>(
2794 *this, getPrivateOperands(), "private")))
2795 return failure();
2796 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2797 mlir::acc::FirstprivateRecipeOp>(
2798 *this, getFirstprivateOperands(), "firstprivate")))
2799 return failure();
2800 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2801 mlir::acc::ReductionRecipeOp>(
2802 *this, getReductionOperands(), "reduction")))
2803 return failure();
2804
2806 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2807 getWaitOperandsDeviceTypeAttr(), "wait")))
2808 return failure();
2809
2811 getAsyncOperandsDeviceTypeAttr(),
2812 "async")))
2813 return failure();
2814
2816 return failure();
2817
2818 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2819}
2820
2821void acc::SerialOp::addAsyncOnly(
2822 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2823 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2824 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2825}
2826
2827void acc::SerialOp::addAsyncOperand(
2828 MLIRContext *context, mlir::Value newValue,
2829 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2830 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2831 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2832 getAsyncOperandsMutable()));
2833}
2834
2835void acc::SerialOp::addWaitOnly(
2836 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2837 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2838 effectiveDeviceTypes));
2839}
2840void acc::SerialOp::addWaitOperands(
2841 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2842 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2843
2845 if (getWaitOperandsSegments())
2846 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2847
2848 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2849 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2850 getWaitOperandsMutable(), segments));
2851 setWaitOperandsSegments(segments);
2852
2854 if (getHasWaitDevnumAttr())
2855 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2856 hasDevnums.insert(
2857 hasDevnums.end(),
2858 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2859 mlir::BoolAttr::get(context, hasDevnum));
2860 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2861}
2862
2863void acc::SerialOp::addPrivatization(MLIRContext *context,
2864 mlir::acc::PrivateOp op,
2865 mlir::acc::PrivateRecipeOp recipe) {
2866 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2867 getPrivateOperandsMutable().append(op.getResult());
2868}
2869
2870void acc::SerialOp::addFirstPrivatization(
2871 MLIRContext *context, mlir::acc::FirstprivateOp op,
2872 mlir::acc::FirstprivateRecipeOp recipe) {
2873 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2874 getFirstprivateOperandsMutable().append(op.getResult());
2875}
2876
2877void acc::SerialOp::addReduction(MLIRContext *context,
2878 mlir::acc::ReductionOp op,
2879 mlir::acc::ReductionRecipeOp recipe) {
2880 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2881 getReductionOperandsMutable().append(op.getResult());
2882}
2883
2884//===----------------------------------------------------------------------===//
2885// KernelsOp
2886//===----------------------------------------------------------------------===//
2887
2888unsigned KernelsOp::getNumDataOperands() {
2889 return getDataClauseOperands().size();
2890}
2891
2892Value KernelsOp::getDataOperand(unsigned i) {
2893 unsigned numOptional = getAsyncOperands().size();
2894 numOptional += getWaitOperands().size();
2895 numOptional += getNumGangs().size();
2896 numOptional += getNumWorkers().size();
2897 numOptional += getVectorLength().size();
2898 numOptional += getIfCond() ? 1 : 0;
2899 numOptional += getSelfCond() ? 1 : 0;
2900 return getOperand(numOptional + i);
2901}
2902
2903bool acc::KernelsOp::hasAsyncOnly() {
2904 return hasAsyncOnly(mlir::acc::DeviceType::None);
2905}
2906
2907bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2908 return hasDeviceType(getAsyncOnly(), deviceType);
2909}
2910
2911mlir::Value acc::KernelsOp::getAsyncValue() {
2912 return getAsyncValue(mlir::acc::DeviceType::None);
2913}
2914
2915mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2917 getAsyncOperands(), deviceType);
2918}
2919
2920mlir::Value acc::KernelsOp::getNumWorkersValue() {
2921 return getNumWorkersValue(mlir::acc::DeviceType::None);
2922}
2923
2925acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2926 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2927 deviceType);
2928}
2929
2930mlir::Value acc::KernelsOp::getVectorLengthValue() {
2931 return getVectorLengthValue(mlir::acc::DeviceType::None);
2932}
2933
2935acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2936 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2937 getVectorLength(), deviceType);
2938}
2939
2940mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2941 return getNumGangsValues(mlir::acc::DeviceType::None);
2942}
2943
2945KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2946 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2947 getNumGangsSegments(), deviceType);
2948}
2949
2950bool acc::KernelsOp::hasWaitOnly() {
2951 return hasWaitOnly(mlir::acc::DeviceType::None);
2952}
2953
2954bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2955 return hasDeviceType(getWaitOnly(), deviceType);
2956}
2957
2958mlir::Operation::operand_range KernelsOp::getWaitValues() {
2959 return getWaitValues(mlir::acc::DeviceType::None);
2960}
2961
2963KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2965 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2966 getHasWaitDevnum(), deviceType);
2967}
2968
2969mlir::Value KernelsOp::getWaitDevnum() {
2970 return getWaitDevnum(mlir::acc::DeviceType::None);
2971}
2972
2973mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2974 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2975 getWaitOperandsSegments(), getHasWaitDevnum(),
2976 deviceType);
2977}
2978
2979LogicalResult acc::KernelsOp::verify() {
2981 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2982 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2983 return failure();
2984
2986 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2987 getWaitOperandsDeviceTypeAttr(), "wait")))
2988 return failure();
2989
2990 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2991 getNumWorkersDeviceTypeAttr(),
2992 "num_workers")))
2993 return failure();
2994
2995 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2996 getVectorLengthDeviceTypeAttr(),
2997 "vector_length")))
2998 return failure();
2999
3001 getAsyncOperandsDeviceTypeAttr(),
3002 "async")))
3003 return failure();
3004
3006 return failure();
3007
3008 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
3009}
3010
3011void acc::KernelsOp::addPrivatization(MLIRContext *context,
3012 mlir::acc::PrivateOp op,
3013 mlir::acc::PrivateRecipeOp recipe) {
3014 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3015 getPrivateOperandsMutable().append(op.getResult());
3016}
3017
3018void acc::KernelsOp::addFirstPrivatization(
3019 MLIRContext *context, mlir::acc::FirstprivateOp op,
3020 mlir::acc::FirstprivateRecipeOp recipe) {
3021 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3022 getFirstprivateOperandsMutable().append(op.getResult());
3023}
3024
3025void acc::KernelsOp::addReduction(MLIRContext *context,
3026 mlir::acc::ReductionOp op,
3027 mlir::acc::ReductionRecipeOp recipe) {
3028 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3029 getReductionOperandsMutable().append(op.getResult());
3030}
3031
3032void acc::KernelsOp::addNumWorkersOperand(
3033 MLIRContext *context, mlir::Value newValue,
3034 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3035 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3036 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3037 getNumWorkersMutable()));
3038}
3039
3040void acc::KernelsOp::addVectorLengthOperand(
3041 MLIRContext *context, mlir::Value newValue,
3042 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3043 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3044 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3045 getVectorLengthMutable()));
3046}
3047void acc::KernelsOp::addAsyncOnly(
3048 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3049 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3050 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3051}
3052
3053void acc::KernelsOp::addAsyncOperand(
3054 MLIRContext *context, mlir::Value newValue,
3055 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3056 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3057 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3058 getAsyncOperandsMutable()));
3059}
3060
3061void acc::KernelsOp::addNumGangsOperands(
3062 MLIRContext *context, mlir::ValueRange newValues,
3063 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3065 if (getNumGangsSegmentsAttr())
3066 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
3067
3068 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3069 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3070 getNumGangsMutable(), segments));
3071
3072 setNumGangsSegments(segments);
3073}
3074
3075void acc::KernelsOp::addWaitOnly(
3076 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3077 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3078 effectiveDeviceTypes));
3079}
3080void acc::KernelsOp::addWaitOperands(
3081 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3082 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3083
3085 if (getWaitOperandsSegments())
3086 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3087
3088 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3089 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3090 getWaitOperandsMutable(), segments));
3091 setWaitOperandsSegments(segments);
3092
3094 if (getHasWaitDevnumAttr())
3095 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3096 hasDevnums.insert(
3097 hasDevnums.end(),
3098 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3099 mlir::BoolAttr::get(context, hasDevnum));
3100 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3101}
3102
3103//===----------------------------------------------------------------------===//
3104// HostDataOp
3105//===----------------------------------------------------------------------===//
3106
3107LogicalResult acc::HostDataOp::verify() {
3108 if (getDataClauseOperands().empty())
3109 return emitError("at least one operand must appear on the host_data "
3110 "operation");
3111
3113 for (mlir::Value operand : getDataClauseOperands()) {
3114 auto useDeviceOp =
3115 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3116 if (!useDeviceOp)
3117 return emitError("expect data entry operation as defining op");
3118
3119 // Check for duplicate use_device clauses
3120 if (!seenVars.insert(useDeviceOp.getVar()).second)
3121 return emitError("duplicate use_device variable");
3122 }
3123 return success();
3124}
3125
3126void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3127 MLIRContext *context) {
3128 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3129}
3130
3131//===----------------------------------------------------------------------===//
3132// LoopOp
3133//===----------------------------------------------------------------------===//
3134
3135static ParseResult parseGangValue(
3136 OpAsmParser &parser, llvm::StringRef keyword,
3139 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
3140 bool &needCommaBetweenValues, bool &newValue) {
3141 if (succeeded(parser.parseOptionalKeyword(keyword))) {
3142 if (parser.parseEqual())
3143 return failure();
3144 if (parser.parseOperand(operands.emplace_back()) ||
3145 parser.parseColonType(types.emplace_back()))
3146 return failure();
3147 attributes.push_back(gangArgType);
3148 needCommaBetweenValues = true;
3149 newValue = true;
3150 }
3151 return success();
3152}
3153
3154static ParseResult parseGangClause(
3155 OpAsmParser &parser,
3157 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3158 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3159 mlir::ArrayAttr &gangOnlyDeviceType) {
3160 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3161 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3162 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3164 bool needCommaBetweenValues = false;
3165 bool needCommaBeforeOperands = false;
3166
3167 if (failed(parser.parseOptionalLParen())) {
3168 // Gang only keyword
3169 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3170 parser.getContext(), mlir::acc::DeviceType::None));
3171 gangOnlyDeviceType =
3172 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3173 return success();
3174 }
3175
3176 // Parse gang only attributes
3177 if (succeeded(parser.parseOptionalLSquare())) {
3178 // Parse gang only attributes
3179 if (failed(parser.parseCommaSeparatedList([&]() {
3180 if (parser.parseAttribute(
3181 gangOnlyDeviceTypeAttributes.emplace_back()))
3182 return failure();
3183 return success();
3184 })))
3185 return failure();
3186 if (parser.parseRSquare())
3187 return failure();
3188 needCommaBeforeOperands = true;
3189 }
3190
3191 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3192 mlir::acc::GangArgType::Num);
3193 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3194 mlir::acc::GangArgType::Dim);
3195 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3196 parser.getContext(), mlir::acc::GangArgType::Static);
3197
3198 do {
3199 if (needCommaBeforeOperands) {
3200 needCommaBeforeOperands = false;
3201 continue;
3202 }
3203
3204 if (failed(parser.parseLBrace()))
3205 return failure();
3206
3207 int32_t crtOperandsSize = gangOperands.size();
3208 while (true) {
3209 bool newValue = false;
3210 bool needValue = false;
3211 if (needCommaBetweenValues) {
3212 if (succeeded(parser.parseOptionalComma()))
3213 needValue = true; // expect a new value after comma.
3214 else
3215 break;
3216 }
3217
3218 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3219 gangOperands, gangOperandsType,
3220 gangArgTypeAttributes, argNum,
3221 needCommaBetweenValues, newValue)))
3222 return failure();
3223 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3224 gangOperands, gangOperandsType,
3225 gangArgTypeAttributes, argDim,
3226 needCommaBetweenValues, newValue)))
3227 return failure();
3228 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3229 gangOperands, gangOperandsType,
3230 gangArgTypeAttributes, argStatic,
3231 needCommaBetweenValues, newValue)))
3232 return failure();
3233
3234 if (!newValue && needValue) {
3235 parser.emitError(parser.getCurrentLocation(),
3236 "new value expected after comma");
3237 return failure();
3238 }
3239
3240 if (!newValue)
3241 break;
3242 }
3243
3244 if (gangOperands.empty())
3245 return parser.emitError(
3246 parser.getCurrentLocation(),
3247 "expect at least one of num, dim or static values");
3248
3249 if (failed(parser.parseRBrace()))
3250 return failure();
3251
3252 if (succeeded(parser.parseOptionalLSquare())) {
3253 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3254 parser.parseRSquare())
3255 return failure();
3256 } else {
3257 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3258 parser.getContext(), mlir::acc::DeviceType::None));
3259 }
3260
3261 seg.push_back(gangOperands.size() - crtOperandsSize);
3262
3263 } while (succeeded(parser.parseOptionalComma()));
3264
3265 if (failed(parser.parseRParen()))
3266 return failure();
3267
3268 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3269 gangArgTypeAttributes.end());
3270 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3271 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3272
3274 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3275 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3276
3277 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3278 return success();
3279}
3280
3282 mlir::OperandRange operands, mlir::TypeRange types,
3283 std::optional<mlir::ArrayAttr> gangArgTypes,
3284 std::optional<mlir::ArrayAttr> deviceTypes,
3285 std::optional<mlir::DenseI32ArrayAttr> segments,
3286 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3287
3288 if (operands.begin() == operands.end() &&
3289 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3290 return;
3291 }
3292
3293 p << "(";
3294
3295 printDeviceTypes(p, gangOnlyDeviceTypes);
3296
3297 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3298 hasDeviceTypeValues(deviceTypes))
3299 p << ", ";
3300
3301 if (hasDeviceTypeValues(deviceTypes)) {
3302 unsigned opIdx = 0;
3303 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3304 p << "{";
3305 llvm::interleaveComma(
3306 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3307 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3308 (*gangArgTypes)[opIdx]);
3309 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3310 p << LoopOp::getGangNumKeyword();
3311 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3312 p << LoopOp::getGangDimKeyword();
3313 else if (gangArgTypeAttr.getValue() ==
3314 mlir::acc::GangArgType::Static)
3315 p << LoopOp::getGangStaticKeyword();
3316 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3317 ++opIdx;
3318 });
3319 p << "}";
3320 printSingleDeviceType(p, it.value());
3321 });
3322 }
3323 p << ")";
3324}
3325
3327 std::optional<mlir::ArrayAttr> segments,
3328 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3329 if (!segments)
3330 return false;
3331 for (auto attr : *segments) {
3332 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3333 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3334 return true;
3335 }
3336 return false;
3337}
3338
3339/// Check for duplicates in the DeviceType array attribute.
3340/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3341static std::optional<mlir::acc::DeviceType>
3342checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3343 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3344 if (!deviceTypes)
3345 return std::nullopt;
3346 for (auto attr : deviceTypes) {
3347 auto deviceTypeAttr =
3348 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3349 if (!deviceTypeAttr)
3350 return mlir::acc::DeviceType::None;
3351 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3352 return deviceTypeAttr.getValue();
3353 }
3354 return std::nullopt;
3355}
3356
3357LogicalResult acc::LoopOp::verify() {
3358 if (getUpperbound().size() != getStep().size())
3359 return emitError() << "number of upperbounds expected to be the same as "
3360 "number of steps";
3361
3362 if (getUpperbound().size() != getLowerbound().size())
3363 return emitError() << "number of upperbounds expected to be the same as "
3364 "number of lowerbounds";
3365
3366 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3367 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3368 return emitError() << "inclusiveUpperbound size is expected to be the same"
3369 << " as upperbound size";
3370
3371 // Check collapse
3372 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3373 return emitOpError() << "collapse device_type attr must be define when"
3374 << " collapse attr is present";
3375
3376 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3377 getCollapseAttr().getValue().size() !=
3378 getCollapseDeviceTypeAttr().getValue().size())
3379 return emitOpError() << "collapse attribute count must match collapse"
3380 << " device_type count";
3381 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3382 return emitOpError() << "duplicate device_type `"
3383 << acc::stringifyDeviceType(*duplicateDeviceType)
3384 << "` found in collapseDeviceType attribute";
3385
3386 // Check gang
3387 if (!getGangOperands().empty()) {
3388 if (!getGangOperandsArgType())
3389 return emitOpError() << "gangOperandsArgType attribute must be defined"
3390 << " when gang operands are present";
3391
3392 if (getGangOperands().size() !=
3393 getGangOperandsArgTypeAttr().getValue().size())
3394 return emitOpError() << "gangOperandsArgType attribute count must match"
3395 << " gangOperands count";
3396 }
3397 if (getGangAttr()) {
3398 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3399 return emitOpError() << "duplicate device_type `"
3400 << acc::stringifyDeviceType(*duplicateDeviceType)
3401 << "` found in gang attribute";
3402 }
3403
3405 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3406 getGangOperandsDeviceTypeAttr(), "gang")))
3407 return failure();
3408
3409 // Check worker
3410 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3411 return emitOpError() << "duplicate device_type `"
3412 << acc::stringifyDeviceType(*duplicateDeviceType)
3413 << "` found in worker attribute";
3414 if (auto duplicateDeviceType =
3415 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3416 return emitOpError() << "duplicate device_type `"
3417 << acc::stringifyDeviceType(*duplicateDeviceType)
3418 << "` found in workerNumOperandsDeviceType attribute";
3419 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3420 getWorkerNumOperandsDeviceTypeAttr(),
3421 "worker")))
3422 return failure();
3423
3424 // Check vector
3425 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3426 return emitOpError() << "duplicate device_type `"
3427 << acc::stringifyDeviceType(*duplicateDeviceType)
3428 << "` found in vector attribute";
3429 if (auto duplicateDeviceType =
3430 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3431 return emitOpError() << "duplicate device_type `"
3432 << acc::stringifyDeviceType(*duplicateDeviceType)
3433 << "` found in vectorOperandsDeviceType attribute";
3434 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3435 getVectorOperandsDeviceTypeAttr(),
3436 "vector")))
3437 return failure();
3438
3440 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3441 getTileOperandsDeviceTypeAttr(), "tile")))
3442 return failure();
3443
3444 // auto, independent and seq attribute are mutually exclusive.
3445 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3446 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3447 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3448 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3449 return emitError() << "only one of auto, independent, seq can be present "
3450 "at the same time";
3451 }
3452
3453 // Check that at least one of auto, independent, or seq is present
3454 // for the device-independent default clauses.
3455 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3456 return attr.getValue() == mlir::acc::DeviceType::None;
3457 };
3458 bool hasDefaultSeq =
3459 getSeqAttr()
3460 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3461 hasDeviceNone)
3462 : false;
3463 bool hasDefaultIndependent =
3464 getIndependentAttr()
3465 ? llvm::any_of(
3466 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3467 hasDeviceNone)
3468 : false;
3469 bool hasDefaultAuto =
3470 getAuto_Attr()
3471 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3472 hasDeviceNone)
3473 : false;
3474 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3475 return emitError()
3476 << "at least one of auto, independent, seq must be present";
3477 }
3478
3479 // Gang, worker and vector are incompatible with seq.
3480 if (getSeqAttr()) {
3481 for (auto attr : getSeqAttr()) {
3482 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3483 if (hasVector(deviceTypeAttr.getValue()) ||
3484 getVectorValue(deviceTypeAttr.getValue()) ||
3485 hasWorker(deviceTypeAttr.getValue()) ||
3486 getWorkerValue(deviceTypeAttr.getValue()) ||
3487 hasGang(deviceTypeAttr.getValue()) ||
3488 getGangValue(mlir::acc::GangArgType::Num,
3489 deviceTypeAttr.getValue()) ||
3490 getGangValue(mlir::acc::GangArgType::Dim,
3491 deviceTypeAttr.getValue()) ||
3492 getGangValue(mlir::acc::GangArgType::Static,
3493 deviceTypeAttr.getValue()))
3494 return emitError() << "gang, worker or vector cannot appear with seq";
3495 }
3496 }
3497
3498 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3499 mlir::acc::PrivateRecipeOp>(
3500 *this, getPrivateOperands(), "private")))
3501 return failure();
3502
3503 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3504 mlir::acc::FirstprivateRecipeOp>(
3505 *this, getFirstprivateOperands(), "firstprivate")))
3506 return failure();
3507
3508 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3509 mlir::acc::ReductionRecipeOp>(
3510 *this, getReductionOperands(), "reduction")))
3511 return failure();
3512
3513 if (getCombined().has_value() &&
3514 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3515 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3516 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3517 return emitError("unexpected combined constructs attribute");
3518 }
3519
3520 // Check non-empty body().
3521 if (getRegion().empty())
3522 return emitError("expected non-empty body.");
3523
3524 if (getUnstructured()) {
3525 if (!isContainerLike())
3526 return emitError(
3527 "unstructured acc.loop must not have induction variables");
3528 } else if (isContainerLike()) {
3529 // When it is container-like - it is expected to hold a loop-like operation.
3530 // Obtain the maximum collapse count - we use this to check that there
3531 // are enough loops contained.
3532 uint64_t collapseCount = getCollapseValue().value_or(1);
3533 if (getCollapseAttr()) {
3534 for (auto collapseEntry : getCollapseAttr()) {
3535 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3536 if (intAttr.getValue().getZExtValue() > collapseCount)
3537 collapseCount = intAttr.getValue().getZExtValue();
3538 }
3539 }
3540
3541 // We want to check that we find enough loop-like operations inside.
3542 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3543 // level.
3544 mlir::Operation *expectedParent = this->getOperation();
3545 bool foundSibling = false;
3546 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3547 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3548 // This effectively checks that we are not looking at a sibling loop.
3549 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3550 expectedParent) {
3551 foundSibling = true;
3553 }
3554
3555 collapseCount--;
3556 expectedParent = op;
3557 }
3558 // We found enough contained loops.
3559 if (collapseCount == 0)
3562 });
3563
3564 if (foundSibling)
3565 return emitError("found sibling loops inside container-like acc.loop");
3566 if (collapseCount != 0)
3567 return emitError("failed to find enough loop-like operations inside "
3568 "container-like acc.loop");
3569 }
3570
3571 return success();
3572}
3573
3574unsigned LoopOp::getNumDataOperands() {
3575 return getReductionOperands().size() + getPrivateOperands().size() +
3576 getFirstprivateOperands().size();
3577}
3578
3579Value LoopOp::getDataOperand(unsigned i) {
3580 unsigned numOptional =
3581 getLowerbound().size() + getUpperbound().size() + getStep().size();
3582 numOptional += getGangOperands().size();
3583 numOptional += getVectorOperands().size();
3584 numOptional += getWorkerNumOperands().size();
3585 numOptional += getTileOperands().size();
3586 numOptional += getCacheOperands().size();
3587 return getOperand(numOptional + i);
3588}
3589
3590bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3591
3592bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3593 return hasDeviceType(getAuto_(), deviceType);
3594}
3595
3596bool LoopOp::hasIndependent() {
3597 return hasIndependent(mlir::acc::DeviceType::None);
3598}
3599
3600bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3601 return hasDeviceType(getIndependent(), deviceType);
3602}
3603
3604bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3605
3606bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3607 return hasDeviceType(getSeq(), deviceType);
3608}
3609
3610mlir::Value LoopOp::getVectorValue() {
3611 return getVectorValue(mlir::acc::DeviceType::None);
3612}
3613
3614mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3615 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3616 getVectorOperands(), deviceType);
3617}
3618
3619bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3620
3621bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3622 return hasDeviceType(getVector(), deviceType);
3623}
3624
3625mlir::Value LoopOp::getWorkerValue() {
3626 return getWorkerValue(mlir::acc::DeviceType::None);
3627}
3628
3629mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3630 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3631 getWorkerNumOperands(), deviceType);
3632}
3633
3634bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3635
3636bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3637 return hasDeviceType(getWorker(), deviceType);
3638}
3639
3640mlir::Operation::operand_range LoopOp::getTileValues() {
3641 return getTileValues(mlir::acc::DeviceType::None);
3642}
3643
3645LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3646 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3647 getTileOperandsSegments(), deviceType);
3648}
3649
3650std::optional<int64_t> LoopOp::getCollapseValue() {
3651 return getCollapseValue(mlir::acc::DeviceType::None);
3652}
3653
3654std::optional<int64_t>
3655LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3656 if (!getCollapseAttr())
3657 return std::nullopt;
3658 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3659 auto intAttr =
3660 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3661 return intAttr.getValue().getZExtValue();
3662 }
3663 return std::nullopt;
3664}
3665
3666mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3667 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3668}
3669
3670mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3671 mlir::acc::DeviceType deviceType) {
3672 if (getGangOperands().empty())
3673 return {};
3674 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3675 int32_t nbOperandsBefore = 0;
3676 for (unsigned i = 0; i < *pos; ++i)
3677 nbOperandsBefore += (*getGangOperandsSegments())[i];
3679 getGangOperands()
3680 .drop_front(nbOperandsBefore)
3681 .take_front((*getGangOperandsSegments())[*pos]);
3682
3683 int32_t argTypeIdx = nbOperandsBefore;
3684 for (auto value : values) {
3685 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3686 (*getGangOperandsArgType())[argTypeIdx]);
3687 if (gangArgTypeAttr.getValue() == gangArgType)
3688 return value;
3689 ++argTypeIdx;
3690 }
3691 }
3692 return {};
3693}
3694
3695bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3696
3697bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3698 return hasDeviceType(getGang(), deviceType);
3699}
3700
3701llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3702 return {&getRegion()};
3703}
3704
3705/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3706/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3707/// `(` ssa-id-and-type-list `)`
3708/// region
3709ParseResult
3712 SmallVectorImpl<Type> &lowerboundType,
3714 SmallVectorImpl<Type> &upperboundType,
3716 SmallVectorImpl<Type> &stepType) {
3717
3719 if (succeeded(
3720 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3721 if (parser.parseLParen() ||
3722 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3723 /*allowType=*/true) ||
3724 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3725 parser.parseOperandList(lowerbound, inductionVars.size(),
3727 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3728 parser.parseKeyword("to") || parser.parseLParen() ||
3729 parser.parseOperandList(upperbound, inductionVars.size(),
3731 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3732 parser.parseKeyword("step") || parser.parseLParen() ||
3733 parser.parseOperandList(step, inductionVars.size(),
3735 parser.parseColonTypeList(stepType) || parser.parseRParen())
3736 return failure();
3737 }
3738 return parser.parseRegion(region, inductionVars);
3739}
3740
3742 ValueRange lowerbound, TypeRange lowerboundType,
3743 ValueRange upperbound, TypeRange upperboundType,
3744 ValueRange steps, TypeRange stepType) {
3745 ValueRange regionArgs = region.front().getArguments();
3746 if (!regionArgs.empty()) {
3747 p << acc::LoopOp::getControlKeyword() << "(";
3748 llvm::interleaveComma(regionArgs, p,
3749 [&p](Value v) { p << v << " : " << v.getType(); });
3750 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3751 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3752 << " : " << stepType << ") ";
3753 }
3754 p.printRegion(region, /*printEntryBlockArgs=*/false);
3755}
3756
3757void acc::LoopOp::addSeq(MLIRContext *context,
3758 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3759 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3760 effectiveDeviceTypes));
3761}
3762
3763void acc::LoopOp::addIndependent(
3764 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3765 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3766 context, getIndependentAttr(), effectiveDeviceTypes));
3767}
3768
3769void acc::LoopOp::addAuto(MLIRContext *context,
3770 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3771 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3772 effectiveDeviceTypes));
3773}
3774
3775void acc::LoopOp::setCollapseForDeviceTypes(
3776 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3777 llvm::APInt value) {
3780
3781 assert((getCollapseAttr() == nullptr) ==
3782 (getCollapseDeviceTypeAttr() == nullptr));
3783 assert(value.getBitWidth() == 64);
3784
3785 if (getCollapseAttr()) {
3786 for (const auto &existing :
3787 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3788 newValues.push_back(std::get<0>(existing));
3789 newDeviceTypes.push_back(std::get<1>(existing));
3790 }
3791 }
3792
3793 if (effectiveDeviceTypes.empty()) {
3794 // If the effective device-types list is empty, this is before there are any
3795 // being applied by device_type, so this should be added as a 'none'.
3796 newValues.push_back(
3797 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3798 newDeviceTypes.push_back(
3799 acc::DeviceTypeAttr::get(context, DeviceType::None));
3800 } else {
3801 for (DeviceType dt : effectiveDeviceTypes) {
3802 newValues.push_back(
3803 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3804 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3805 }
3806 }
3807
3808 setCollapseAttr(ArrayAttr::get(context, newValues));
3809 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3810}
3811
3812void acc::LoopOp::setTileForDeviceTypes(
3813 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3814 ValueRange values) {
3816 if (getTileOperandsSegments())
3817 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3818
3819 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3820 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3821 getTileOperandsMutable(), segments));
3822
3823 setTileOperandsSegments(segments);
3824}
3825
3826void acc::LoopOp::addVectorOperand(
3827 MLIRContext *context, mlir::Value newValue,
3828 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3829 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3830 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3831 newValue, getVectorOperandsMutable()));
3832}
3833
3834void acc::LoopOp::addEmptyVector(
3835 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3836 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3837 effectiveDeviceTypes));
3838}
3839
3840void acc::LoopOp::addWorkerNumOperand(
3841 MLIRContext *context, mlir::Value newValue,
3842 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3843 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3844 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3845 newValue, getWorkerNumOperandsMutable()));
3846}
3847
3848void acc::LoopOp::addEmptyWorker(
3849 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3850 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3851 effectiveDeviceTypes));
3852}
3853
3854void acc::LoopOp::addEmptyGang(
3855 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3856 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3857 effectiveDeviceTypes));
3858}
3859
3860bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3861 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3862 return attr.getValue() == dt;
3863 };
3864 auto testFromArr = [=](ArrayAttr arr) -> bool {
3865 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3866 };
3867
3868 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3869 return true;
3870 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3871 return true;
3872 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3873 return true;
3874
3875 return false;
3876}
3877
3878bool acc::LoopOp::hasDefaultGangWorkerVector() {
3879 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3880 hasGang() || getGangValue(GangArgType::Num) ||
3881 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3882}
3883
3884acc::LoopParMode
3885acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3886 if (hasSeq(deviceType))
3887 return LoopParMode::loop_seq;
3888 if (hasAuto(deviceType))
3889 return LoopParMode::loop_auto;
3890 if (hasIndependent(deviceType))
3891 return LoopParMode::loop_independent;
3892 if (hasSeq())
3893 return LoopParMode::loop_seq;
3894 if (hasAuto())
3895 return LoopParMode::loop_auto;
3896 assert(hasIndependent() &&
3897 "loop must have default auto, seq, or independent");
3898 return LoopParMode::loop_independent;
3899}
3900
3901void acc::LoopOp::addGangOperands(
3902 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3905 if (std::optional<ArrayRef<int32_t>> existingSegments =
3906 getGangOperandsSegments())
3907 llvm::copy(*existingSegments, std::back_inserter(segments));
3908
3909 unsigned beforeCount = segments.size();
3910
3911 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3912 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3913 getGangOperandsMutable(), segments));
3914
3915 setGangOperandsSegments(segments);
3916
3917 // This is a bit of extra work to make sure we update the 'types' correctly by
3918 // adding to the types collection the correct number of times. We could
3919 // potentially add something similar to the
3920 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3921 // excessive for a one-off case.
3922 unsigned numAdded = segments.size() - beforeCount;
3923
3924 if (numAdded > 0) {
3926 if (getGangOperandsArgTypeAttr())
3927 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3928
3929 for (auto i : llvm::index_range(0u, numAdded)) {
3930 llvm::transform(argTypes, std::back_inserter(gangTypes),
3931 [=](mlir::acc::GangArgType gangTy) {
3932 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3933 });
3934 (void)i;
3935 }
3936
3937 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3938 }
3939}
3940
3941void acc::LoopOp::addPrivatization(MLIRContext *context,
3942 mlir::acc::PrivateOp op,
3943 mlir::acc::PrivateRecipeOp recipe) {
3944 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3945 getPrivateOperandsMutable().append(op.getResult());
3946}
3947
3948void acc::LoopOp::addFirstPrivatization(
3949 MLIRContext *context, mlir::acc::FirstprivateOp op,
3950 mlir::acc::FirstprivateRecipeOp recipe) {
3951 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3952 getFirstprivateOperandsMutable().append(op.getResult());
3953}
3954
3955void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3956 mlir::acc::ReductionRecipeOp recipe) {
3957 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3958 getReductionOperandsMutable().append(op.getResult());
3959}
3960
3961//===----------------------------------------------------------------------===//
3962// DataOp
3963//===----------------------------------------------------------------------===//
3964
3965LogicalResult acc::DataOp::verify() {
3966 // 2.6.5. Data Construct restriction
3967 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3968 // attach, or default clause must appear on a data construct.
3969 if (getOperands().empty() && !getDefaultAttr())
3970 return emitError("at least one operand or the default attribute "
3971 "must appear on the data operation");
3972
3973 for (mlir::Value operand : getDataClauseOperands())
3974 if (isa<BlockArgument>(operand) ||
3975 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3976 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3977 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3978 operand.getDefiningOp()))
3979 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3980 "as defining op");
3981
3983 return failure();
3984
3985 return success();
3986}
3987
3988unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3989
3990Value DataOp::getDataOperand(unsigned i) {
3991 unsigned numOptional = getIfCond() ? 1 : 0;
3992 numOptional += getAsyncOperands().size() ? 1 : 0;
3993 numOptional += getWaitOperands().size();
3994 return getOperand(numOptional + i);
3995}
3996
3997bool acc::DataOp::hasAsyncOnly() {
3998 return hasAsyncOnly(mlir::acc::DeviceType::None);
3999}
4000
4001bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4002 return hasDeviceType(getAsyncOnly(), deviceType);
4003}
4004
4005mlir::Value DataOp::getAsyncValue() {
4006 return getAsyncValue(mlir::acc::DeviceType::None);
4007}
4008
4009mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4011 getAsyncOperands(), deviceType);
4012}
4013
4014bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4015
4016bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4017 return hasDeviceType(getWaitOnly(), deviceType);
4018}
4019
4020mlir::Operation::operand_range DataOp::getWaitValues() {
4021 return getWaitValues(mlir::acc::DeviceType::None);
4022}
4023
4025DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4027 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4028 getHasWaitDevnum(), deviceType);
4029}
4030
4031mlir::Value DataOp::getWaitDevnum() {
4032 return getWaitDevnum(mlir::acc::DeviceType::None);
4033}
4034
4035mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4036 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4037 getWaitOperandsSegments(), getHasWaitDevnum(),
4038 deviceType);
4039}
4040
4041void acc::DataOp::addAsyncOnly(
4042 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4043 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4044 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4045}
4046
4047void acc::DataOp::addAsyncOperand(
4048 MLIRContext *context, mlir::Value newValue,
4049 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4050 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4051 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4052 getAsyncOperandsMutable()));
4053}
4054
4055void acc::DataOp::addWaitOnly(MLIRContext *context,
4056 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4057 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4058 effectiveDeviceTypes));
4059}
4060
4061void acc::DataOp::addWaitOperands(
4062 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4063 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4064
4066 if (getWaitOperandsSegments())
4067 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4068
4069 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4070 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4071 getWaitOperandsMutable(), segments));
4072 setWaitOperandsSegments(segments);
4073
4075 if (getHasWaitDevnumAttr())
4076 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4077 hasDevnums.insert(
4078 hasDevnums.end(),
4079 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4080 mlir::BoolAttr::get(context, hasDevnum));
4081 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4082}
4083
4084//===----------------------------------------------------------------------===//
4085// ExitDataOp
4086//===----------------------------------------------------------------------===//
4087
4088LogicalResult acc::ExitDataOp::verify() {
4089 // 2.6.6. Data Exit Directive restriction
4090 // At least one copyout, delete, or detach clause must appear on an exit data
4091 // directive.
4092 if (getDataClauseOperands().empty())
4093 return emitError("at least one operand must be present in dataOperands on "
4094 "the exit data operation");
4095
4096 // The async attribute represent the async clause without value. Therefore the
4097 // attribute and operand cannot appear at the same time.
4098 if (getAsyncOperand() && getAsync())
4099 return emitError("async attribute cannot appear with asyncOperand");
4100
4101 // The wait attribute represent the wait clause without values. Therefore the
4102 // attribute and operands cannot appear at the same time.
4103 if (!getWaitOperands().empty() && getWait())
4104 return emitError("wait attribute cannot appear with waitOperands");
4105
4106 if (getWaitDevnum() && getWaitOperands().empty())
4107 return emitError("wait_devnum cannot appear without waitOperands");
4108
4109 return success();
4110}
4111
4112unsigned ExitDataOp::getNumDataOperands() {
4113 return getDataClauseOperands().size();
4114}
4115
4116Value ExitDataOp::getDataOperand(unsigned i) {
4117 unsigned numOptional = getIfCond() ? 1 : 0;
4118 numOptional += getAsyncOperand() ? 1 : 0;
4119 numOptional += getWaitDevnum() ? 1 : 0;
4120 return getOperand(getWaitOperands().size() + numOptional + i);
4121}
4122
4123void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4124 MLIRContext *context) {
4125 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4126}
4127
4128void ExitDataOp::addAsyncOnly(MLIRContext *context,
4129 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4130 assert(effectiveDeviceTypes.empty());
4131 assert(!getAsyncAttr());
4132 assert(!getAsyncOperand());
4133
4134 setAsyncAttr(mlir::UnitAttr::get(context));
4135}
4136
4137void ExitDataOp::addAsyncOperand(
4138 MLIRContext *context, mlir::Value newValue,
4139 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4140 assert(effectiveDeviceTypes.empty());
4141 assert(!getAsyncAttr());
4142 assert(!getAsyncOperand());
4143
4144 getAsyncOperandMutable().append(newValue);
4145}
4146
4147void ExitDataOp::addWaitOnly(MLIRContext *context,
4148 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4149 assert(effectiveDeviceTypes.empty());
4150 assert(!getWaitAttr());
4151 assert(getWaitOperands().empty());
4152 assert(!getWaitDevnum());
4153
4154 setWaitAttr(mlir::UnitAttr::get(context));
4155}
4156
4157void ExitDataOp::addWaitOperands(
4158 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4159 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4160 assert(effectiveDeviceTypes.empty());
4161 assert(!getWaitAttr());
4162 assert(getWaitOperands().empty());
4163 assert(!getWaitDevnum());
4164
4165 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4166 // operands list.
4167 if (hasDevnum) {
4168 getWaitDevnumMutable().append(newValues.front());
4169 newValues = newValues.drop_front();
4170 }
4171
4172 getWaitOperandsMutable().append(newValues);
4173}
4174
4175//===----------------------------------------------------------------------===//
4176// EnterDataOp
4177//===----------------------------------------------------------------------===//
4178
4179LogicalResult acc::EnterDataOp::verify() {
4180 // 2.6.6. Data Enter Directive restriction
4181 // At least one copyin, create, or attach clause must appear on an enter data
4182 // directive.
4183 if (getDataClauseOperands().empty())
4184 return emitError("at least one operand must be present in dataOperands on "
4185 "the enter data operation");
4186
4187 // The async attribute represent the async clause without value. Therefore the
4188 // attribute and operand cannot appear at the same time.
4189 if (getAsyncOperand() && getAsync())
4190 return emitError("async attribute cannot appear with asyncOperand");
4191
4192 // The wait attribute represent the wait clause without values. Therefore the
4193 // attribute and operands cannot appear at the same time.
4194 if (!getWaitOperands().empty() && getWait())
4195 return emitError("wait attribute cannot appear with waitOperands");
4196
4197 if (getWaitDevnum() && getWaitOperands().empty())
4198 return emitError("wait_devnum cannot appear without waitOperands");
4199
4200 for (mlir::Value operand : getDataClauseOperands())
4201 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4202 operand.getDefiningOp()))
4203 return emitError("expect data entry operation as defining op");
4204
4205 return success();
4206}
4207
4208unsigned EnterDataOp::getNumDataOperands() {
4209 return getDataClauseOperands().size();
4210}
4211
4212Value EnterDataOp::getDataOperand(unsigned i) {
4213 unsigned numOptional = getIfCond() ? 1 : 0;
4214 numOptional += getAsyncOperand() ? 1 : 0;
4215 numOptional += getWaitDevnum() ? 1 : 0;
4216 return getOperand(getWaitOperands().size() + numOptional + i);
4217}
4218
4219void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4220 MLIRContext *context) {
4221 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4222}
4223
4224void EnterDataOp::addAsyncOnly(
4225 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4226 assert(effectiveDeviceTypes.empty());
4227 assert(!getAsyncAttr());
4228 assert(!getAsyncOperand());
4229
4230 setAsyncAttr(mlir::UnitAttr::get(context));
4231}
4232
4233void EnterDataOp::addAsyncOperand(
4234 MLIRContext *context, mlir::Value newValue,
4235 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4236 assert(effectiveDeviceTypes.empty());
4237 assert(!getAsyncAttr());
4238 assert(!getAsyncOperand());
4239
4240 getAsyncOperandMutable().append(newValue);
4241}
4242
4243void EnterDataOp::addWaitOnly(MLIRContext *context,
4244 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4245 assert(effectiveDeviceTypes.empty());
4246 assert(!getWaitAttr());
4247 assert(getWaitOperands().empty());
4248 assert(!getWaitDevnum());
4249
4250 setWaitAttr(mlir::UnitAttr::get(context));
4251}
4252
4253void EnterDataOp::addWaitOperands(
4254 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4255 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4256 assert(effectiveDeviceTypes.empty());
4257 assert(!getWaitAttr());
4258 assert(getWaitOperands().empty());
4259 assert(!getWaitDevnum());
4260
4261 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4262 // operands list.
4263 if (hasDevnum) {
4264 getWaitDevnumMutable().append(newValues.front());
4265 newValues = newValues.drop_front();
4266 }
4267
4268 getWaitOperandsMutable().append(newValues);
4269}
4270
4271//===----------------------------------------------------------------------===//
4272// AtomicReadOp
4273//===----------------------------------------------------------------------===//
4274
4275LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4276
4277//===----------------------------------------------------------------------===//
4278// AtomicWriteOp
4279//===----------------------------------------------------------------------===//
4280
4281LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4282
4283//===----------------------------------------------------------------------===//
4284// AtomicUpdateOp
4285//===----------------------------------------------------------------------===//
4286
4287LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4288 PatternRewriter &rewriter) {
4289 if (op.isNoOp()) {
4290 rewriter.eraseOp(op);
4291 return success();
4292 }
4293
4294 if (Value writeVal = op.getWriteOpVal()) {
4295 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4296 op.getIfCond());
4297 return success();
4298 }
4299
4300 return failure();
4301}
4302
4303LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4304
4305LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4306
4307//===----------------------------------------------------------------------===//
4308// AtomicCaptureOp
4309//===----------------------------------------------------------------------===//
4310
4311AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4312 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4313 return op;
4314 return dyn_cast<AtomicReadOp>(getSecondOp());
4315}
4316
4317AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4318 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4319 return op;
4320 return dyn_cast<AtomicWriteOp>(getSecondOp());
4321}
4322
4323AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4324 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4325 return op;
4326 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4327}
4328
4329LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4330
4331//===----------------------------------------------------------------------===//
4332// DeclareEnterOp
4333//===----------------------------------------------------------------------===//
4334
4335template <typename Op>
4336static LogicalResult
4338 bool requireAtLeastOneOperand = true) {
4339 if (operands.empty() && requireAtLeastOneOperand)
4340 return emitError(
4341 op->getLoc(),
4342 "at least one operand must appear on the declare operation");
4343
4344 for (mlir::Value operand : operands) {
4345 if (isa<BlockArgument>(operand) ||
4346 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4347 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4348 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4349 operand.getDefiningOp()))
4350 return op.emitError(
4351 "expect valid declare data entry operation or acc.getdeviceptr "
4352 "as defining op");
4353
4354 mlir::Value var{getVar(operand.getDefiningOp())};
4355 assert(var && "declare operands can only be data entry operations which "
4356 "must have var");
4357 (void)var;
4358 std::optional<mlir::acc::DataClause> dataClauseOptional{
4359 getDataClause(operand.getDefiningOp())};
4360 assert(dataClauseOptional.has_value() &&
4361 "declare operands can only be data entry operations which must have "
4362 "dataClause");
4363 (void)dataClauseOptional;
4364 }
4365
4366 return success();
4367}
4368
4369LogicalResult acc::DeclareEnterOp::verify() {
4370 return checkDeclareOperands(*this, this->getDataClauseOperands());
4371}
4372
4373//===----------------------------------------------------------------------===//
4374// DeclareExitOp
4375//===----------------------------------------------------------------------===//
4376
4377LogicalResult acc::DeclareExitOp::verify() {
4378 if (getToken())
4379 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4380 /*requireAtLeastOneOperand=*/false);
4381 return checkDeclareOperands(*this, this->getDataClauseOperands());
4382}
4383
4384//===----------------------------------------------------------------------===//
4385// DeclareOp
4386//===----------------------------------------------------------------------===//
4387
4388LogicalResult acc::DeclareOp::verify() {
4389 return checkDeclareOperands(*this, this->getDataClauseOperands());
4390}
4391
4392//===----------------------------------------------------------------------===//
4393// RoutineOp
4394//===----------------------------------------------------------------------===//
4395
4396static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4397 acc::DeviceType dtype) {
4398 unsigned parallelism = 0;
4399 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4400 parallelism += op.hasWorker(dtype) ? 1 : 0;
4401 parallelism += op.hasVector(dtype) ? 1 : 0;
4402 parallelism += op.hasSeq(dtype) ? 1 : 0;
4403 return parallelism;
4404}
4405
4406LogicalResult acc::RoutineOp::verify() {
4407 unsigned baseParallelism =
4408 getParallelismForDeviceType(*this, acc::DeviceType::None);
4409
4410 if (baseParallelism > 1)
4411 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4412 "be present at the same time";
4413
4414 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4415 ++dtypeInt) {
4416 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4417 if (dtype == acc::DeviceType::None)
4418 continue;
4419 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4420
4421 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4422 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4423 "be present at the same time for device_type `"
4424 << acc::stringifyDeviceType(dtype) << "`";
4425 }
4426
4427 return success();
4428}
4429
4430static ParseResult parseBindName(OpAsmParser &parser,
4431 mlir::ArrayAttr &bindIdName,
4432 mlir::ArrayAttr &bindStrName,
4433 mlir::ArrayAttr &deviceIdTypes,
4434 mlir::ArrayAttr &deviceStrTypes) {
4435 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4436 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4437 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4438 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4439
4440 if (failed(parser.parseCommaSeparatedList([&]() {
4441 mlir::Attribute newAttr;
4442 bool isSymbolRefAttr;
4443 auto parseResult = parser.parseAttribute(newAttr);
4444 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4445 bindIdNameAttrs.push_back(symbolRefAttr);
4446 isSymbolRefAttr = true;
4447 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4448 bindStrNameAttrs.push_back(stringAttr);
4449 isSymbolRefAttr = false;
4450 }
4451 if (parseResult)
4452 return failure();
4453 if (failed(parser.parseOptionalLSquare())) {
4454 if (isSymbolRefAttr) {
4455 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4456 parser.getContext(), mlir::acc::DeviceType::None));
4457 } else {
4458 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4459 parser.getContext(), mlir::acc::DeviceType::None));
4460 }
4461 } else {
4462 if (isSymbolRefAttr) {
4463 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4464 parser.parseRSquare())
4465 return failure();
4466 } else {
4467 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4468 parser.parseRSquare())
4469 return failure();
4470 }
4471 }
4472 return success();
4473 })))
4474 return failure();
4475
4476 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4477 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4478 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4479 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4480
4481 return success();
4482}
4483
4485 std::optional<mlir::ArrayAttr> bindIdName,
4486 std::optional<mlir::ArrayAttr> bindStrName,
4487 std::optional<mlir::ArrayAttr> deviceIdTypes,
4488 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4489 // Create combined vectors for all bind names and device types
4492
4493 // Append bindIdName and deviceIdTypes
4494 if (hasDeviceTypeValues(deviceIdTypes)) {
4495 allBindNames.append(bindIdName->begin(), bindIdName->end());
4496 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4497 }
4498
4499 // Append bindStrName and deviceStrTypes
4500 if (hasDeviceTypeValues(deviceStrTypes)) {
4501 allBindNames.append(bindStrName->begin(), bindStrName->end());
4502 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4503 }
4504
4505 // Print the combined sequence
4506 if (!allBindNames.empty())
4507 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4508 [&](const auto &pair) {
4509 p << std::get<0>(pair);
4510 printSingleDeviceType(p, std::get<1>(pair));
4511 });
4512}
4513
4514static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4515 mlir::ArrayAttr &gang,
4516 mlir::ArrayAttr &gangDim,
4517 mlir::ArrayAttr &gangDimDeviceTypes) {
4518
4519 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4520 gangDimDeviceTypeAttrs;
4521 bool needCommaBeforeOperands = false;
4522
4523 // Gang keyword only
4524 if (failed(parser.parseOptionalLParen())) {
4525 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4526 parser.getContext(), mlir::acc::DeviceType::None));
4527 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4528 return success();
4529 }
4530
4531 // Parse keyword only attributes
4532 if (succeeded(parser.parseOptionalLSquare())) {
4533 if (failed(parser.parseCommaSeparatedList([&]() {
4534 if (parser.parseAttribute(gangAttrs.emplace_back()))
4535 return failure();
4536 return success();
4537 })))
4538 return failure();
4539 if (parser.parseRSquare())
4540 return failure();
4541 needCommaBeforeOperands = true;
4542 }
4543
4544 if (needCommaBeforeOperands && failed(parser.parseComma()))
4545 return failure();
4546
4547 if (failed(parser.parseCommaSeparatedList([&]() {
4548 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4549 parser.parseColon() ||
4550 parser.parseAttribute(gangDimAttrs.emplace_back()))
4551 return failure();
4552 if (succeeded(parser.parseOptionalLSquare())) {
4553 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4554 parser.parseRSquare())
4555 return failure();
4556 } else {
4557 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4558 parser.getContext(), mlir::acc::DeviceType::None));
4559 }
4560 return success();
4561 })))
4562 return failure();
4563
4564 if (failed(parser.parseRParen()))
4565 return failure();
4566
4567 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4568 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4569 gangDimDeviceTypes =
4570 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4571
4572 return success();
4573}
4574
4576 std::optional<mlir::ArrayAttr> gang,
4577 std::optional<mlir::ArrayAttr> gangDim,
4578 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4579
4580 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4581 gang->size() == 1) {
4582 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4583 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4584 return;
4585 }
4586
4587 p << "(";
4588
4589 printDeviceTypes(p, gang);
4590
4591 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4592 p << ", ";
4593
4594 if (hasDeviceTypeValues(gangDimDeviceTypes))
4595 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4596 [&](const auto &pair) {
4597 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4598 p << std::get<0>(pair);
4599 printSingleDeviceType(p, std::get<1>(pair));
4600 });
4601
4602 p << ")";
4603}
4604
4605static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4606 mlir::ArrayAttr &deviceTypes) {
4608 // Keyword only
4609 if (failed(parser.parseOptionalLParen())) {
4610 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4611 parser.getContext(), mlir::acc::DeviceType::None));
4612 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4613 return success();
4614 }
4615
4616 // Parse device type attributes
4617 if (succeeded(parser.parseOptionalLSquare())) {
4618 if (failed(parser.parseCommaSeparatedList([&]() {
4619 if (parser.parseAttribute(attributes.emplace_back()))
4620 return failure();
4621 return success();
4622 })))
4623 return failure();
4624 if (parser.parseRSquare() || parser.parseRParen())
4625 return failure();
4626 }
4627 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4628 return success();
4629}
4630
4631static void
4633 std::optional<mlir::ArrayAttr> deviceTypes) {
4634
4635 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4636 auto deviceTypeAttr =
4637 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4638 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4639 return;
4640 }
4641
4642 if (!hasDeviceTypeValues(deviceTypes))
4643 return;
4644
4645 p << "([";
4646 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4647 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4648 p << dTypeAttr;
4649 });
4650 p << "])";
4651}
4652
4653bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4654
4655bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4656 return hasDeviceType(getWorker(), deviceType);
4657}
4658
4659bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4660
4661bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4662 return hasDeviceType(getVector(), deviceType);
4663}
4664
4665bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4666
4667bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4668 return hasDeviceType(getSeq(), deviceType);
4669}
4670
4671std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4672RoutineOp::getBindNameValue() {
4673 return getBindNameValue(mlir::acc::DeviceType::None);
4674}
4675
4676std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4677RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4678 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4679 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4680 return std::nullopt;
4681 }
4682
4683 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4684 auto attr = (*getBindIdName())[*pos];
4685 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4686 assert(symbolRefAttr && "expected SymbolRef");
4687 return symbolRefAttr;
4688 }
4689
4690 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4691 auto attr = (*getBindStrName())[*pos];
4692 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4693 assert(stringAttr && "expected String");
4694 return stringAttr;
4695 }
4696
4697 return std::nullopt;
4698}
4699
4700bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4701
4702bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4703 return hasDeviceType(getGang(), deviceType);
4704}
4705
4706std::optional<int64_t> RoutineOp::getGangDimValue() {
4707 return getGangDimValue(mlir::acc::DeviceType::None);
4708}
4709
4710std::optional<int64_t>
4711RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4712 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4713 return std::nullopt;
4714 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4715 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4716 return intAttr.getInt();
4717 }
4718 return std::nullopt;
4719}
4720
4721void RoutineOp::addSeq(MLIRContext *context,
4722 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4723 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4724 effectiveDeviceTypes));
4725}
4726
4727void RoutineOp::addVector(MLIRContext *context,
4728 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4729 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4730 effectiveDeviceTypes));
4731}
4732
4733void RoutineOp::addWorker(MLIRContext *context,
4734 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4735 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4736 effectiveDeviceTypes));
4737}
4738
4739void RoutineOp::addGang(MLIRContext *context,
4740 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4741 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4742 effectiveDeviceTypes));
4743}
4744
4745void RoutineOp::addGang(MLIRContext *context,
4746 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4747 uint64_t val) {
4750
4751 if (getGangDimAttr())
4752 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4753 if (getGangDimDeviceTypeAttr())
4754 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4755
4756 assert(dimValues.size() == deviceTypes.size());
4757
4758 if (effectiveDeviceTypes.empty()) {
4759 dimValues.push_back(
4760 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4761 deviceTypes.push_back(
4762 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4763 } else {
4764 for (DeviceType dt : effectiveDeviceTypes) {
4765 dimValues.push_back(
4766 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4767 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4768 }
4769 }
4770 assert(dimValues.size() == deviceTypes.size());
4771
4772 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4773 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4774}
4775
4776void RoutineOp::addBindStrName(MLIRContext *context,
4777 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4778 mlir::StringAttr val) {
4779 unsigned before = getBindStrNameDeviceTypeAttr()
4780 ? getBindStrNameDeviceTypeAttr().size()
4781 : 0;
4782
4783 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4784 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4785 unsigned after = getBindStrNameDeviceTypeAttr().size();
4786
4788 if (getBindStrNameAttr())
4789 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4790 for (unsigned i = 0; i < after - before; ++i)
4791 vals.push_back(val);
4792
4793 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4794}
4795
4796void RoutineOp::addBindIDName(MLIRContext *context,
4797 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4798 mlir::SymbolRefAttr val) {
4799 unsigned before =
4800 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4801
4802 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4803 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4804 unsigned after = getBindIdNameDeviceTypeAttr().size();
4805
4807 if (getBindIdNameAttr())
4808 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4809 for (unsigned i = 0; i < after - before; ++i)
4810 vals.push_back(val);
4811
4812 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4813}
4814
4815//===----------------------------------------------------------------------===//
4816// InitOp
4817//===----------------------------------------------------------------------===//
4818
4819LogicalResult acc::InitOp::verify() {
4820 Operation *currOp = *this;
4821 while ((currOp = currOp->getParentOp()))
4822 if (isComputeOperation(currOp))
4823 return emitOpError("cannot be nested in a compute operation");
4824 return success();
4825}
4826
4827void acc::InitOp::addDeviceType(MLIRContext *context,
4828 mlir::acc::DeviceType deviceType) {
4830 if (getDeviceTypesAttr())
4831 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4832
4833 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4834 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4835}
4836
4837//===----------------------------------------------------------------------===//
4838// ShutdownOp
4839//===----------------------------------------------------------------------===//
4840
4841LogicalResult acc::ShutdownOp::verify() {
4842 Operation *currOp = *this;
4843 while ((currOp = currOp->getParentOp()))
4844 if (isComputeOperation(currOp))
4845 return emitOpError("cannot be nested in a compute operation");
4846 return success();
4847}
4848
4849void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4850 mlir::acc::DeviceType deviceType) {
4852 if (getDeviceTypesAttr())
4853 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4854
4855 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4856 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4857}
4858
4859//===----------------------------------------------------------------------===//
4860// SetOp
4861//===----------------------------------------------------------------------===//
4862
4863LogicalResult acc::SetOp::verify() {
4864 Operation *currOp = *this;
4865 while ((currOp = currOp->getParentOp()))
4866 if (isComputeOperation(currOp))
4867 return emitOpError("cannot be nested in a compute operation");
4868 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4869 return emitOpError("at least one default_async, device_num, or device_type "
4870 "operand must appear");
4871 return success();
4872}
4873
4874//===----------------------------------------------------------------------===//
4875// UpdateOp
4876//===----------------------------------------------------------------------===//
4877
4878LogicalResult acc::UpdateOp::verify() {
4879 // At least one of host or device should have a value.
4880 if (getDataClauseOperands().empty())
4881 return emitError("at least one value must be present in dataOperands");
4882
4884 getAsyncOperandsDeviceTypeAttr(),
4885 "async")))
4886 return failure();
4887
4889 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4890 getWaitOperandsDeviceTypeAttr(), "wait")))
4891 return failure();
4892
4894 return failure();
4895
4896 for (mlir::Value operand : getDataClauseOperands())
4897 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4898 operand.getDefiningOp()))
4899 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4900 "as defining op");
4901
4902 return success();
4903}
4904
4905unsigned UpdateOp::getNumDataOperands() {
4906 return getDataClauseOperands().size();
4907}
4908
4909Value UpdateOp::getDataOperand(unsigned i) {
4910 unsigned numOptional = getAsyncOperands().size();
4911 numOptional += getIfCond() ? 1 : 0;
4912 return getOperand(getWaitOperands().size() + numOptional + i);
4913}
4914
4915void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4916 MLIRContext *context) {
4917 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4918}
4919
4920bool UpdateOp::hasAsyncOnly() {
4921 return hasAsyncOnly(mlir::acc::DeviceType::None);
4922}
4923
4924bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4925 return hasDeviceType(getAsyncOnly(), deviceType);
4926}
4927
4928mlir::Value UpdateOp::getAsyncValue() {
4929 return getAsyncValue(mlir::acc::DeviceType::None);
4930}
4931
4932mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4934 return {};
4935
4936 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4937 return getAsyncOperands()[*pos];
4938
4939 return {};
4940}
4941
4942bool UpdateOp::hasWaitOnly() {
4943 return hasWaitOnly(mlir::acc::DeviceType::None);
4944}
4945
4946bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4947 return hasDeviceType(getWaitOnly(), deviceType);
4948}
4949
4950mlir::Operation::operand_range UpdateOp::getWaitValues() {
4951 return getWaitValues(mlir::acc::DeviceType::None);
4952}
4953
4955UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4957 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4958 getHasWaitDevnum(), deviceType);
4959}
4960
4961mlir::Value UpdateOp::getWaitDevnum() {
4962 return getWaitDevnum(mlir::acc::DeviceType::None);
4963}
4964
4965mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4966 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4967 getWaitOperandsSegments(), getHasWaitDevnum(),
4968 deviceType);
4969}
4970
4971void UpdateOp::addAsyncOnly(MLIRContext *context,
4972 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4973 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4974 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4975}
4976
4977void UpdateOp::addAsyncOperand(
4978 MLIRContext *context, mlir::Value newValue,
4979 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4980 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4981 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4982 getAsyncOperandsMutable()));
4983}
4984
4985void UpdateOp::addWaitOnly(MLIRContext *context,
4986 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4987 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4988 effectiveDeviceTypes));
4989}
4990
4991void UpdateOp::addWaitOperands(
4992 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4993 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4994
4996 if (getWaitOperandsSegments())
4997 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4998
4999 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
5000 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
5001 getWaitOperandsMutable(), segments));
5002 setWaitOperandsSegments(segments);
5003
5005 if (getHasWaitDevnumAttr())
5006 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
5007 hasDevnums.insert(
5008 hasDevnums.end(),
5009 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
5010 mlir::BoolAttr::get(context, hasDevnum));
5011 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5012}
5013
5014//===----------------------------------------------------------------------===//
5015// WaitOp
5016//===----------------------------------------------------------------------===//
5017
5018LogicalResult acc::WaitOp::verify() {
5019 // The async attribute represent the async clause without value. Therefore the
5020 // attribute and operand cannot appear at the same time.
5021 if (getAsyncOperand() && getAsync())
5022 return emitError("async attribute cannot appear with asyncOperand");
5023
5024 if (getWaitDevnum() && getWaitOperands().empty())
5025 return emitError("wait_devnum cannot appear without waitOperands");
5026
5027 return success();
5028}
5029
5030#define GET_OP_CLASSES
5031#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5032
5033#define GET_ATTRDEF_CLASSES
5034#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5035
5036#define GET_TYPEDEF_CLASSES
5037#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5038
5039//===----------------------------------------------------------------------===//
5040// acc dialect utilities
5041//===----------------------------------------------------------------------===//
5042
5045 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
5047 accDataClauseOp)
5048 .Case<ACC_DATA_ENTRY_OPS>(
5049 [&](auto entry) { return entry.getVarPtr(); })
5050 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5051 [&](auto exit) { return exit.getVarPtr(); })
5052 .Default([&](mlir::Operation *) {
5054 })};
5055 return varPtr;
5056}
5057
5059 auto varPtr{
5061 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
5062 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5063 return varPtr;
5064}
5065
5067 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
5068 .Case<ACC_DATA_ENTRY_OPS>(
5069 [&](auto entry) { return entry.getVarType(); })
5070 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5071 [&](auto exit) { return exit.getVarType(); })
5072 .Default([&](mlir::Operation *) { return mlir::Type(); })};
5073 return varType;
5074}
5075
5078 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
5080 accDataClauseOp)
5081 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5082 [&](auto dataClause) { return dataClause.getAccPtr(); })
5083 .Default([&](mlir::Operation *) {
5085 })};
5086 return accPtr;
5087}
5088
5090 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
5092 [&](auto dataClause) { return dataClause.getAccVar(); })
5093 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5094 return accPtr;
5095}
5096
5098 auto varPtrPtr{
5100 .Case<ACC_DATA_ENTRY_OPS>(
5101 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
5102 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5103 return varPtrPtr;
5104}
5105
5110 accDataClauseOp)
5111 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5113 dataClause.getBounds().begin(), dataClause.getBounds().end());
5114 })
5115 .Default([&](mlir::Operation *) {
5117 })};
5118 return bounds;
5119}
5120
5124 accDataClauseOp)
5125 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5127 dataClause.getAsyncOperands().begin(),
5128 dataClause.getAsyncOperands().end());
5129 })
5130 .Default([&](mlir::Operation *) {
5132 });
5133}
5134
5135mlir::ArrayAttr
5138 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5139 return dataClause.getAsyncOperandsDeviceTypeAttr();
5140 })
5141 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5142}
5143
5144mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
5147 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
5148 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5149}
5150
5151std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
5152 auto name{
5154 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
5155 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
5156 return {};
5157 })};
5158 return name;
5159}
5160
5161std::optional<mlir::acc::DataClause>
5163 auto dataClause{
5165 accDataEntryOp)
5166 .Case<ACC_DATA_ENTRY_OPS>(
5167 [&](auto entry) { return entry.getDataClause(); })
5168 .Default([&](mlir::Operation *) { return std::nullopt; })};
5169 return dataClause;
5170}
5171
5173 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5174 .Case<ACC_DATA_ENTRY_OPS>(
5175 [&](auto entry) { return entry.getImplicit(); })
5176 .Default([&](mlir::Operation *) { return false; })};
5177 return implicit;
5178}
5179
5181 auto dataOperands{
5184 [&](auto entry) { return entry.getDataClauseOperands(); })
5185 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5186 return dataOperands;
5187}
5188
5191 auto dataOperands{
5194 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5195 .Default([&](mlir::Operation *) { return nullptr; })};
5196 return dataOperands;
5197}
5198
5199mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5200 auto recipe{
5202 .Case<ACC_DATA_ENTRY_OPS>(
5203 [&](auto entry) { return entry.getRecipeAttr(); })
5204 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5205 return recipe;
5206}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4575
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1458
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3326
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1945
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4430
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:819
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1472
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:599
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2441
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:812
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:752
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:583
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:721
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2452
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2357
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:525
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4632
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3135
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2692
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3342
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4337
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:659
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2646
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:543
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:641
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:675
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1223
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3710
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1900
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2488
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2027
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1233
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:667
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:730
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:554
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:567
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2227
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:706
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3741
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4605
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4514
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2340
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2515
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2631
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2294
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1211
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2607
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:793
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3154
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1680
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2676
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2271
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:685
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1914
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2588
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:529
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3281
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2526
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:764
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:619
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1955
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4396
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2277
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2712
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4484
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Return the location of the next token.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:156
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:69
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:46
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:54
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5089
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5058
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5077
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5162
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5190
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5107
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5180
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5151
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5172
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5199
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5122
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:167
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5097
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:204
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5144
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5066
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5044
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5136
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.