MLIR 23.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8
15#include "mlir/IR/Builders.h"
19#include "mlir/IR/IRMapping.h"
20#include "mlir/IR/Matchers.h"
22#include "mlir/IR/SymbolTable.h"
23#include "mlir/Support/LLVM.h"
25#include "llvm/ADT/SmallSet.h"
26#include "llvm/ADT/TypeSwitch.h"
27#include "llvm/Support/LogicalResult.h"
28#include <variant>
29
30using namespace mlir;
31using namespace acc;
32
33#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
37#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
38
39namespace {
40
41static bool isScalarLikeType(Type type) {
42 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
43}
44
45/// Helper function to attach the `VarName` attribute to an operation
46/// if a variable name is provided.
47static void attachVarNameAttr(Operation *op, OpBuilder &builder,
48 StringRef varName) {
49 if (!varName.empty()) {
50 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
51 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
52 }
53}
54
55template <typename T>
56struct MemRefPointerLikeModel
57 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
58 Type getElementType(Type pointer) const {
59 return cast<T>(pointer).getElementType();
60 }
61
62 mlir::acc::VariableTypeCategory
63 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
64 Type varType) const {
65 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
66 return mappableTy.getTypeCategory(varPtr);
67 }
68 auto memrefTy = cast<T>(pointer);
69 if (!memrefTy.hasRank()) {
70 // This memref is unranked - aka it could have any rank, including a
71 // rank of 0 which could mean scalar. For now, return uncategorized.
72 return mlir::acc::VariableTypeCategory::uncategorized;
73 }
74
75 if (memrefTy.getRank() == 0) {
76 if (isScalarLikeType(memrefTy.getElementType())) {
77 return mlir::acc::VariableTypeCategory::scalar;
78 }
79 // Zero-rank non-scalar - need further analysis to determine the type
80 // category. For now, return uncategorized.
81 return mlir::acc::VariableTypeCategory::uncategorized;
82 }
83
84 // It has a rank - must be an array.
85 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
86 return mlir::acc::VariableTypeCategory::array;
87 }
88
89 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
90 StringRef varName, Type varType, Value originalVar,
91 bool &needsFree) const {
92 auto memrefTy = cast<MemRefType>(pointer);
93
94 // Check if this is a static memref (all dimensions are known) - if yes
95 // then we can generate an alloca operation.
96 if (memrefTy.hasStaticShape()) {
97 needsFree = false; // alloca doesn't need deallocation
98 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
99 attachVarNameAttr(allocaOp, builder, varName);
100 return allocaOp.getResult();
101 }
102
103 // For dynamic memrefs, extract sizes from the original variable if
104 // provided. Otherwise they cannot be handled.
105 if (originalVar && originalVar.getType() == memrefTy &&
106 memrefTy.hasRank()) {
107 SmallVector<Value> dynamicSizes;
108 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
109 if (memrefTy.isDynamicDim(i)) {
110 // Extract the size of dimension i from the original variable
111 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
112 auto dimSize =
113 memref::DimOp::create(builder, loc, originalVar, indexValue);
114 dynamicSizes.push_back(dimSize);
115 }
116 // Note: We only add dynamic sizes to the dynamicSizes array
117 // Static dimensions are handled automatically by AllocOp
118 }
119 needsFree = true; // alloc needs deallocation
120 auto allocOp =
121 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
122 attachVarNameAttr(allocOp, builder, varName);
123 return allocOp.getResult();
124 }
125
126 // TODO: Unranked not yet supported.
127 return {};
128 }
129
130 bool genFree(Type pointer, OpBuilder &builder, Location loc,
131 TypedValue<PointerLikeType> varToFree, Value allocRes,
132 Type varType) const {
133 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
134 // Use allocRes if provided to determine the allocation type
135 Value valueToInspect = allocRes ? allocRes : memrefValue;
136
137 // Walk through casts to find the original allocation
138 Value currentValue = valueToInspect;
139 Operation *originalAlloc = nullptr;
140
141 // Follow the chain of operations to find the original allocation
142 // even if a casted result is provided.
143 while (currentValue) {
144 if (auto *definingOp = currentValue.getDefiningOp()) {
145 // Check if this is an allocation operation
146 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
147 originalAlloc = definingOp;
148 break;
149 }
150
151 // Check if this is a cast operation we can look through
152 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
153 currentValue = castOp.getSource();
154 continue;
155 }
156
157 // Check for other cast-like operations
158 if (auto reinterpretCastOp =
159 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
160 currentValue = reinterpretCastOp.getSource();
161 continue;
162 }
163
164 // If we can't look through this operation, stop
165 break;
166 }
167 // This is a block argument or similar - can't trace further.
168 break;
169 }
170
171 if (originalAlloc) {
172 if (isa<memref::AllocaOp>(originalAlloc)) {
173 // This is an alloca - no dealloc needed, but return true (success)
174 return true;
175 }
176 if (isa<memref::AllocOp>(originalAlloc)) {
177 // This is an alloc - generate dealloc on varToFree
178 memref::DeallocOp::create(builder, loc, memrefValue);
179 return true;
180 }
181 }
182 }
183
184 return false;
185 }
186
187 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
188 TypedValue<PointerLikeType> destination,
189 TypedValue<PointerLikeType> source, Type varType) const {
190 // Generate a copy operation between two memrefs
191 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
192 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
193
194 // As per memref documentation, source and destination must have same
195 // element type and shape in order to be compatible. We do not want to fail
196 // with an IR verification error - thus check that before generating the
197 // copy operation.
198 if (destMemref && srcMemref &&
199 destMemref.getType().getElementType() ==
200 srcMemref.getType().getElementType() &&
201 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
202 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
203 return true;
204 }
205
206 return false;
207 }
208
209 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
211 Type valueType) const {
212 // Load from a memref - only valid for scalar memrefs (rank 0).
213 // This is because the address computation for memrefs is part of the load
214 // (and not computed separately), but the API does not have arguments for
215 // indexing.
216 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
217 if (!memrefValue)
218 return {};
219
220 auto memrefTy = memrefValue.getType();
221
222 // Only load from scalar memrefs (rank 0)
223 if (memrefTy.getRank() != 0)
224 return {};
225
226 return memref::LoadOp::create(builder, loc, memrefValue);
227 }
228
229 bool genStore(Type pointer, OpBuilder &builder, Location loc,
230 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
231 // Store to a memref - only valid for scalar memrefs (rank 0)
232 // This is because the address computation for memrefs is part of the store
233 // (and not computed separately), but the API does not have arguments for
234 // indexing.
235 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
236 if (!memrefValue)
237 return false;
238
239 auto memrefTy = memrefValue.getType();
240
241 // Only store to scalar memrefs (rank 0)
242 if (memrefTy.getRank() != 0)
243 return false;
244
245 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
246 return true;
247 }
248
249 bool isDeviceData(Type pointer, Value var) const {
250 auto memrefTy = cast<T>(pointer);
251 Attribute memSpace = memrefTy.getMemorySpace();
252 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
253 }
254};
255
256struct LLVMPointerPointerLikeModel
257 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
258 LLVM::LLVMPointerType> {
259 Type getElementType(Type pointer) const { return Type(); }
260
261 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
263 Type valueType) const {
264 // For LLVM pointers, we need the valueType to determine what to load
265 if (!valueType)
266 return {};
267
268 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
269 }
270
271 bool genStore(Type pointer, OpBuilder &builder, Location loc,
272 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
273 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
274 return true;
275 }
276};
277
278struct MemrefAddressOfGlobalModel
279 : public AddressOfGlobalOpInterface::ExternalModel<
280 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
281 SymbolRefAttr getSymbol(Operation *op) const {
282 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
283 return getGlobalOp.getNameAttr();
284 }
285};
286
287struct MemrefGlobalVariableModel
288 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
289 memref::GlobalOp> {
290 bool isConstant(Operation *op) const {
291 auto globalOp = cast<memref::GlobalOp>(op);
292 return globalOp.getConstant();
293 }
294
295 Region *getInitRegion(Operation *op) const {
296 // GlobalOp uses attributes for initialization, not regions
297 return nullptr;
298 }
299
300 bool isDeviceData(Operation *op) const {
301 auto globalOp = cast<memref::GlobalOp>(op);
302 Attribute memSpace = globalOp.getType().getMemorySpace();
303 return isa_and_nonnull<gpu::AddressSpaceAttr>(memSpace);
304 }
305};
306
307struct GPULaunchOffloadRegionModel
308 : public acc::OffloadRegionOpInterface::ExternalModel<
309 GPULaunchOffloadRegionModel, gpu::LaunchOp> {
310 mlir::Region &getOffloadRegion(mlir::Operation *op) const {
311 return cast<gpu::LaunchOp>(op).getBody();
312 }
313};
314
/// Helper function for any of the times we need to modify an ArrayAttr based on
/// a device type list. Returns a new ArrayAttr with all of the
/// existingDeviceTypes, plus the effective new ones (or an added none if the
/// new list is empty).
319mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
320 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
321 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
323 if (existingDeviceTypes)
324 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
325
326 if (newDeviceTypes.empty())
327 deviceTypes.push_back(
328 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
329
330 for (DeviceType dt : newDeviceTypes)
331 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
332
333 return mlir::ArrayAttr::get(context, deviceTypes);
334}
335
336/// Helper function for any of the times we need to add operands that are
337/// affected by a device type list. Returns a new ArrayAttr with all of the
338/// existingDeviceTypes, plus the effective new ones (or an added none, if the
339/// new list is empty). Additionally, adds the arguments to the argCollection
340/// the correct number of times. This will also update a 'segments' array, even
341/// if it won't be used.
342mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
343 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
344 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
345 mlir::MutableOperandRange argCollection,
346 llvm::SmallVector<int32_t> &segments) {
348 if (existingDeviceTypes)
349 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
350
351 if (newDeviceTypes.empty()) {
352 argCollection.append(arguments);
353 segments.push_back(arguments.size());
354 deviceTypes.push_back(
355 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
356 }
357
358 for (DeviceType dt : newDeviceTypes) {
359 argCollection.append(arguments);
360 segments.push_back(arguments.size());
361 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
362 }
363
364 return mlir::ArrayAttr::get(context, deviceTypes);
365}
366
367/// Overload for when the 'segments' aren't needed.
368mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
369 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
370 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
371 mlir::MutableOperandRange argCollection) {
373 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
374 newDeviceTypes, arguments,
375 argCollection, segments);
376}
377} // namespace
378
379//===----------------------------------------------------------------------===//
380// OpenACC operations
381//===----------------------------------------------------------------------===//
382
/// Registers the OpenACC operations, attributes and types from the
/// tablegen-generated lists. It also attaches OpenACC external interface
/// models to types and ops owned by the memref, LLVM and GPU dialects.
void OpenACCDialect::initialize() {
  // The GET_*_LIST macros make each .inc file expand to a comma-separated
  // list of class names, used as the template arguments below.
  addOperations<
#define GET_OP_LIST
#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
      >();
  addAttributes<
#define GET_ATTRDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
      >();
  addTypes<
#define GET_TYPEDEF_LIST
#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
      >();

  // By attaching interfaces here, we make the OpenACC dialect dependent on
  // the other dialects. This is probably better than having dialects like LLVM
  // and memref be dependent on OpenACC.
  // PointerLikeType models: ranked and unranked memrefs, and LLVM pointers.
  MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
      *getContext());
  UnrankedMemRefType::attachInterface<
      MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
  LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
      *getContext());

  // Attach operation interfaces
  // memref.get_global -> AddressOfGlobalOpInterface,
  // memref.global -> GlobalVariableOpInterface,
  // gpu.launch -> OffloadRegionOpInterface.
  memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
      *getContext());
  memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
  gpu::LaunchOp::attachInterface<GPULaunchOffloadRegionModel>(*getContext());
}
413
414//===----------------------------------------------------------------------===//
415// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
416// acc.kernel_environment / acc.data / acc.host_data / acc.loop
417//===----------------------------------------------------------------------===//
418
419/// Generic helper for single-region OpenACC ops that execute their body once
420/// and then return to the parent operation with their results (if any).
421static void
423 RegionBranchPoint point,
425 if (point.isParent()) {
426 regions.push_back(RegionSuccessor(&region));
427 return;
428 }
429
430 regions.push_back(RegionSuccessor::parent());
431}
432
434 RegionSuccessor successor) {
435 return successor.isParent() ? ValueRange(op->getResults()) : ValueRange();
436}
437
438void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
440 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
441 regions);
442}
443
444ValueRange KernelsOp::getSuccessorInputs(RegionSuccessor successor) {
445 return getSingleRegionSuccessorInputs(getOperation(), successor);
446}
447
448void ParallelOp::getSuccessorRegions(
450 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
451 regions);
452}
453
454ValueRange ParallelOp::getSuccessorInputs(RegionSuccessor successor) {
455 return getSingleRegionSuccessorInputs(getOperation(), successor);
456}
457
458void SerialOp::getSuccessorRegions(RegionBranchPoint point,
460 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
461 regions);
462}
463
464ValueRange SerialOp::getSuccessorInputs(RegionSuccessor successor) {
465 return getSingleRegionSuccessorInputs(getOperation(), successor);
466}
467
468void DataOp::getSuccessorRegions(RegionBranchPoint point,
470 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
471 regions);
472}
473
474ValueRange DataOp::getSuccessorInputs(RegionSuccessor successor) {
475 return getSingleRegionSuccessorInputs(getOperation(), successor);
476}
477
478void HostDataOp::getSuccessorRegions(
480 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
481 regions);
482}
483
484ValueRange HostDataOp::getSuccessorInputs(RegionSuccessor successor) {
485 return getSingleRegionSuccessorInputs(getOperation(), successor);
486}
487
488void LoopOp::getSuccessorRegions(RegionBranchPoint point,
490 // Unstructured loops: the body may contain arbitrary CFG and early exits.
491 // At the RegionBranch level, only model entry into the body and exit to the
492 // parent; any backedges are represented inside the region CFG.
493 if (getUnstructured()) {
494 if (point.isParent()) {
495 regions.push_back(RegionSuccessor(&getRegion()));
496 return;
497 }
498 regions.push_back(RegionSuccessor::parent());
499 return;
500 }
501
502 // Structured loops: model a loop-shaped region graph similar to scf.for.
503 regions.push_back(RegionSuccessor(&getRegion()));
504 regions.push_back(RegionSuccessor::parent());
505}
506
507ValueRange LoopOp::getSuccessorInputs(RegionSuccessor successor) {
508 return getSingleRegionSuccessorInputs(getOperation(), successor);
509}
510
511//===----------------------------------------------------------------------===//
512// RegionBranchTerminatorOpInterface
513//===----------------------------------------------------------------------===//
514
516TerminatorOp::getMutableSuccessorOperands(RegionSuccessor /*point*/) {
517 // `acc.terminator` does not forward operands.
518 return MutableOperandRange(getOperation(), /*start=*/0, /*length=*/0);
519}
520
521//===----------------------------------------------------------------------===//
522// device_type support helpers
523//===----------------------------------------------------------------------===//
524
525static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
526 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
527}
528
529static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
530 mlir::acc::DeviceType deviceType) {
531 if (!hasDeviceTypeValues(arrayAttr))
532 return false;
533
534 for (auto attr : *arrayAttr) {
535 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
536 if (deviceTypeAttr.getValue() == deviceType)
537 return true;
538 }
539
540 return false;
541}
542
544 std::optional<mlir::ArrayAttr> deviceTypes) {
545 if (!hasDeviceTypeValues(deviceTypes))
546 return;
547
548 p << "[";
549 llvm::interleaveComma(*deviceTypes, p,
550 [&](mlir::Attribute attr) { p << attr; });
551 p << "]";
552}
553
554static std::optional<unsigned> findSegment(ArrayAttr segments,
555 mlir::acc::DeviceType deviceType) {
556 unsigned segmentIdx = 0;
557 for (auto attr : segments) {
558 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
559 if (deviceTypeAttr.getValue() == deviceType)
560 return std::make_optional(segmentIdx);
561 ++segmentIdx;
562 }
563 return std::nullopt;
564}
565
567getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
569 std::optional<llvm::ArrayRef<int32_t>> segments,
570 mlir::acc::DeviceType deviceType) {
571 if (!arrayAttr)
572 return range.take_front(0);
573 if (auto pos = findSegment(*arrayAttr, deviceType)) {
574 int32_t nbOperandsBefore = 0;
575 for (unsigned i = 0; i < *pos; ++i)
576 nbOperandsBefore += (*segments)[i];
577 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
578 }
579 return range.take_front(0);
580}
581
582static mlir::Value
583getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
585 std::optional<llvm::ArrayRef<int32_t>> segments,
586 std::optional<mlir::ArrayAttr> hasWaitDevnum,
587 mlir::acc::DeviceType deviceType) {
588 if (!hasDeviceTypeValues(deviceTypeAttr))
589 return {};
590 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
591 if (hasWaitDevnum->getValue()[*pos])
592 return getValuesFromSegments(deviceTypeAttr, operands, segments,
593 deviceType)
594 .front();
595 return {};
596}
597
599getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
601 std::optional<llvm::ArrayRef<int32_t>> segments,
602 std::optional<mlir::ArrayAttr> hasWaitDevnum,
603 mlir::acc::DeviceType deviceType) {
604 auto range =
605 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
606 if (range.empty())
607 return range;
608 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
609 if (hasWaitDevnum && *hasWaitDevnum) {
610 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
611 if (boolAttr.getValue())
612 return range.drop_front(1); // first value is devnum
613 }
614 }
615 return range;
616}
617
618template <typename Op>
619static LogicalResult checkWaitAndAsyncConflict(Op op) {
620 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
621 ++dtypeInt) {
622 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
623
624 // The asyncOnly attribute represent the async clause without value.
625 // Therefore the attribute and operand cannot appear at the same time.
626 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
627 op.hasAsyncOnly(dtype))
628 return op.emitError(
629 "asyncOnly attribute cannot appear with asyncOperand");
630
631 // The wait attribute represent the wait clause without values. Therefore
632 // the attribute and operands cannot appear at the same time.
633 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
634 op.hasWaitOnly(dtype))
635 return op.emitError("wait attribute cannot appear with waitOperands");
636 }
637 return success();
638}
639
640template <typename Op>
641static LogicalResult checkVarAndVarType(Op op) {
642 if (!op.getVar())
643 return op.emitError("must have var operand");
644
645 // A variable must have a type that is either pointer-like or mappable.
646 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
647 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
648 return op.emitError("var must be mappable or pointer-like");
649
650 // When it is a pointer-like type, the varType must capture the target type.
651 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
652 op.getVarType() == op.getVar().getType())
653 return op.emitError("varType must capture the element type of var");
654
655 return success();
656}
657
658template <typename Op>
659static LogicalResult checkVarAndAccVar(Op op) {
660 if (op.getVar().getType() != op.getAccVar().getType())
661 return op.emitError("input and output types must match");
662
663 return success();
664}
665
666template <typename Op>
667static LogicalResult checkNoModifier(Op op) {
668 if (op.getModifiers() != acc::DataClauseModifier::none)
669 return op.emitError("no data clause modifiers are allowed");
670 return success();
671}
672
673template <typename Op>
674static LogicalResult
675checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
676 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
677 return op.emitError(
678 "invalid data clause modifiers: " +
679 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
680
681 return success();
682}
683
684template <typename OpT, typename RecipeOpT>
685static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
686 // Mappable types do not need a recipe because it is possible to generate one
687 // from its API. Reject reductions though because no API is available for them
688 // at this time.
689 if (mlir::acc::isMappableType(op.getVar().getType()) &&
690 !std::is_same_v<OpT, acc::ReductionOp>)
691 return success();
692
693 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
694 if (!operandRecipe)
695 return op->emitOpError() << "recipe expected for " << operandName;
696
697 auto decl =
699 if (!decl)
700 return op->emitOpError()
701 << "expected symbol reference " << operandRecipe << " to point to a "
702 << operandName << " declaration";
703 return success();
704}
705
706static ParseResult parseVar(mlir::OpAsmParser &parser,
708 // Either `var` or `varPtr` keyword is required.
709 if (failed(parser.parseOptionalKeyword("varPtr"))) {
710 if (failed(parser.parseKeyword("var")))
711 return failure();
712 }
713 if (failed(parser.parseLParen()))
714 return failure();
715 if (failed(parser.parseOperand(var)))
716 return failure();
717
718 return success();
719}
720
722 mlir::Value var) {
723 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
724 p << "varPtr(";
725 else
726 p << "var(";
727 p.printOperand(var);
728}
729
730static ParseResult parseAccVar(mlir::OpAsmParser &parser,
732 mlir::Type &accVarType) {
733 // Either `accVar` or `accPtr` keyword is required.
734 if (failed(parser.parseOptionalKeyword("accPtr"))) {
735 if (failed(parser.parseKeyword("accVar")))
736 return failure();
737 }
738 if (failed(parser.parseLParen()))
739 return failure();
740 if (failed(parser.parseOperand(var)))
741 return failure();
742 if (failed(parser.parseColon()))
743 return failure();
744 if (failed(parser.parseType(accVarType)))
745 return failure();
746 if (failed(parser.parseRParen()))
747 return failure();
748
749 return success();
750}
751
753 mlir::Value accVar, mlir::Type accVarType) {
754 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
755 p << "accPtr(";
756 else
757 p << "accVar(";
758 p.printOperand(accVar);
759 p << " : ";
760 p.printType(accVarType);
761 p << ")";
762}
763
764static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
765 mlir::Type &varPtrType,
766 mlir::TypeAttr &varTypeAttr) {
767 if (failed(parser.parseType(varPtrType)))
768 return failure();
769 if (failed(parser.parseRParen()))
770 return failure();
771
772 if (succeeded(parser.parseOptionalKeyword("varType"))) {
773 if (failed(parser.parseLParen()))
774 return failure();
775 mlir::Type varType;
776 if (failed(parser.parseType(varType)))
777 return failure();
778 varTypeAttr = mlir::TypeAttr::get(varType);
779 if (failed(parser.parseRParen()))
780 return failure();
781 } else {
782 // Set `varType` from the element type of the type of `varPtr`.
783 if (auto ptrTy = dyn_cast<acc::PointerLikeType>(varPtrType)) {
784 Type elementType = ptrTy.getElementType();
785 // Opaque pointers (e.g. !llvm.ptr) have no element type; fall back to
786 // using varPtrType itself so that the attribute is always valid.
787 varTypeAttr = mlir::TypeAttr::get(elementType ? elementType : varPtrType);
788 } else {
789 varTypeAttr = mlir::TypeAttr::get(varPtrType);
790 }
791 }
792
793 return success();
794}
795
797 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
798 p.printType(varPtrType);
799 p << ")";
800
801 // Print the `varType` only if it differs from the element type of
802 // `varPtr`'s type.
803 mlir::Type varType = varTypeAttr.getValue();
804 mlir::Type typeToCheckAgainst =
805 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
806 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
807 : varPtrType;
808 // Opaque pointers (e.g. !llvm.ptr) have no element type; use varPtrType as
809 // the baseline so that the inferred varType is not redundantly printed.
810 if (!typeToCheckAgainst)
811 typeToCheckAgainst = varPtrType;
812 if (typeToCheckAgainst != varType) {
813 p << " varType(";
814 p.printType(varType);
815 p << ")";
816 }
817}
818
819static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
820 mlir::SymbolRefAttr &recipeAttr) {
821 if (failed(parser.parseAttribute(recipeAttr)))
822 return failure();
823 return success();
824}
825
827 mlir::SymbolRefAttr recipeAttr) {
828 p << recipeAttr;
829}
830
831//===----------------------------------------------------------------------===//
832// DataBoundsOp
833//===----------------------------------------------------------------------===//
834LogicalResult acc::DataBoundsOp::verify() {
835 auto extent = getExtent();
836 auto upperbound = getUpperbound();
837 if (!extent && !upperbound)
838 return emitError("expected extent or upperbound.");
839 return success();
840}
841
842//===----------------------------------------------------------------------===//
843// PrivateOp
844//===----------------------------------------------------------------------===//
845LogicalResult acc::PrivateOp::verify() {
846 if (getDataClause() != acc::DataClause::acc_private)
847 return emitError(
848 "data clause associated with private operation must match its intent");
849 if (failed(checkVarAndVarType(*this)))
850 return failure();
851 if (failed(checkNoModifier(*this)))
852 return failure();
853 if (failed(
855 return failure();
856 return success();
857}
858
859//===----------------------------------------------------------------------===//
860// FirstprivateOp
861//===----------------------------------------------------------------------===//
862LogicalResult acc::FirstprivateOp::verify() {
863 if (getDataClause() != acc::DataClause::acc_firstprivate)
864 return emitError("data clause associated with firstprivate operation must "
865 "match its intent");
866 if (failed(checkVarAndVarType(*this)))
867 return failure();
868 if (failed(checkNoModifier(*this)))
869 return failure();
871 *this, "firstprivate")))
872 return failure();
873 return success();
874}
875
876//===----------------------------------------------------------------------===//
877// ReductionOp
878//===----------------------------------------------------------------------===//
879LogicalResult acc::ReductionOp::verify() {
880 if (getDataClause() != acc::DataClause::acc_reduction)
881 return emitError("data clause associated with reduction operation must "
882 "match its intent");
883 if (failed(checkVarAndVarType(*this)))
884 return failure();
885 if (failed(checkNoModifier(*this)))
886 return failure();
888 *this, "reduction")))
889 return failure();
890 return success();
891}
892
893//===----------------------------------------------------------------------===//
894// DevicePtrOp
895//===----------------------------------------------------------------------===//
896LogicalResult acc::DevicePtrOp::verify() {
897 if (getDataClause() != acc::DataClause::acc_deviceptr)
898 return emitError("data clause associated with deviceptr operation must "
899 "match its intent");
900 if (failed(checkVarAndVarType(*this)))
901 return failure();
902 if (failed(checkVarAndAccVar(*this)))
903 return failure();
904 if (failed(checkNoModifier(*this)))
905 return failure();
906 return success();
907}
908
909//===----------------------------------------------------------------------===//
910// PresentOp
911//===----------------------------------------------------------------------===//
912LogicalResult acc::PresentOp::verify() {
913 if (getDataClause() != acc::DataClause::acc_present)
914 return emitError(
915 "data clause associated with present operation must match its intent");
916 if (failed(checkVarAndVarType(*this)))
917 return failure();
918 if (failed(checkVarAndAccVar(*this)))
919 return failure();
920 if (failed(checkNoModifier(*this)))
921 return failure();
922 return success();
923}
924
925//===----------------------------------------------------------------------===//
926// CopyinOp
927//===----------------------------------------------------------------------===//
928LogicalResult acc::CopyinOp::verify() {
929 // Test for all clauses this operation can be decomposed from:
930 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
931 getDataClause() != acc::DataClause::acc_copyin_readonly &&
932 getDataClause() != acc::DataClause::acc_copy &&
933 getDataClause() != acc::DataClause::acc_reduction)
934 return emitError(
935 "data clause associated with copyin operation must match its intent"
936 " or specify original clause this operation was decomposed from");
937 if (failed(checkVarAndVarType(*this)))
938 return failure();
939 if (failed(checkVarAndAccVar(*this)))
940 return failure();
941 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
942 acc::DataClauseModifier::always |
943 acc::DataClauseModifier::capture)))
944 return failure();
945 return success();
946}
947
948bool acc::CopyinOp::isCopyinReadonly() {
949 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
950 acc::bitEnumContainsAny(getModifiers(),
951 acc::DataClauseModifier::readonly);
952}
953
954//===----------------------------------------------------------------------===//
955// CreateOp
956//===----------------------------------------------------------------------===//
957LogicalResult acc::CreateOp::verify() {
958 // Test for all clauses this operation can be decomposed from:
959 if (getDataClause() != acc::DataClause::acc_create &&
960 getDataClause() != acc::DataClause::acc_create_zero &&
961 getDataClause() != acc::DataClause::acc_copyout &&
962 getDataClause() != acc::DataClause::acc_copyout_zero)
963 return emitError(
964 "data clause associated with create operation must match its intent"
965 " or specify original clause this operation was decomposed from");
966 if (failed(checkVarAndVarType(*this)))
967 return failure();
968 if (failed(checkVarAndAccVar(*this)))
969 return failure();
970 // this op is the entry part of copyout, so it also needs to allow all
971 // modifiers allowed on copyout.
972 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
973 acc::DataClauseModifier::always |
974 acc::DataClauseModifier::capture)))
975 return failure();
976 return success();
977}
978
979bool acc::CreateOp::isCreateZero() {
980 // The zero modifier is encoded in the data clause.
981 return getDataClause() == acc::DataClause::acc_create_zero ||
982 getDataClause() == acc::DataClause::acc_copyout_zero ||
983 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
984}
985
986//===----------------------------------------------------------------------===//
987// NoCreateOp
988//===----------------------------------------------------------------------===//
989LogicalResult acc::NoCreateOp::verify() {
990 if (getDataClause() != acc::DataClause::acc_no_create)
991 return emitError("data clause associated with no_create operation must "
992 "match its intent");
993 if (failed(checkVarAndVarType(*this)))
994 return failure();
995 if (failed(checkVarAndAccVar(*this)))
996 return failure();
997 if (failed(checkNoModifier(*this)))
998 return failure();
999 return success();
1000}
1001
1002//===----------------------------------------------------------------------===//
1003// AttachOp
1004//===----------------------------------------------------------------------===//
1005LogicalResult acc::AttachOp::verify() {
1006 if (getDataClause() != acc::DataClause::acc_attach)
1007 return emitError(
1008 "data clause associated with attach operation must match its intent");
1009 if (failed(checkVarAndVarType(*this)))
1010 return failure();
1011 if (failed(checkVarAndAccVar(*this)))
1012 return failure();
1013 if (failed(checkNoModifier(*this)))
1014 return failure();
1015 return success();
1016}
1017
1018//===----------------------------------------------------------------------===//
1019// DeclareDeviceResidentOp
1020//===----------------------------------------------------------------------===//
1021
1022LogicalResult acc::DeclareDeviceResidentOp::verify() {
1023 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
1024 return emitError("data clause associated with device_resident operation "
1025 "must match its intent");
1026 if (failed(checkVarAndVarType(*this)))
1027 return failure();
1028 if (failed(checkVarAndAccVar(*this)))
1029 return failure();
1030 if (failed(checkNoModifier(*this)))
1031 return failure();
1032 return success();
1033}
1034
1035//===----------------------------------------------------------------------===//
1036// DeclareLinkOp
1037//===----------------------------------------------------------------------===//
1038
1039LogicalResult acc::DeclareLinkOp::verify() {
1040 if (getDataClause() != acc::DataClause::acc_declare_link)
1041 return emitError(
1042 "data clause associated with link operation must match its intent");
1043 if (failed(checkVarAndVarType(*this)))
1044 return failure();
1045 if (failed(checkVarAndAccVar(*this)))
1046 return failure();
1047 if (failed(checkNoModifier(*this)))
1048 return failure();
1049 return success();
1050}
1051
1052//===----------------------------------------------------------------------===//
1053// CopyoutOp
1054//===----------------------------------------------------------------------===//
1055LogicalResult acc::CopyoutOp::verify() {
1056 // Test for all clauses this operation can be decomposed from:
1057 if (getDataClause() != acc::DataClause::acc_copyout &&
1058 getDataClause() != acc::DataClause::acc_copyout_zero &&
1059 getDataClause() != acc::DataClause::acc_copy &&
1060 getDataClause() != acc::DataClause::acc_reduction)
1061 return emitError(
1062 "data clause associated with copyout operation must match its intent"
1063 " or specify original clause this operation was decomposed from");
1064 if (!getVar() || !getAccVar())
1065 return emitError("must have both host and device pointers");
1066 if (failed(checkVarAndVarType(*this)))
1067 return failure();
1068 if (failed(checkVarAndAccVar(*this)))
1069 return failure();
1070 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1071 acc::DataClauseModifier::always |
1072 acc::DataClauseModifier::capture)))
1073 return failure();
1074 return success();
1075}
1076
1077bool acc::CopyoutOp::isCopyoutZero() {
1078 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1079 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1080}
1081
1082//===----------------------------------------------------------------------===//
1083// DeleteOp
1084//===----------------------------------------------------------------------===//
1085LogicalResult acc::DeleteOp::verify() {
1086 // Test for all clauses this operation can be decomposed from:
1087 if (getDataClause() != acc::DataClause::acc_delete &&
1088 getDataClause() != acc::DataClause::acc_create &&
1089 getDataClause() != acc::DataClause::acc_create_zero &&
1090 getDataClause() != acc::DataClause::acc_copyin &&
1091 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1092 getDataClause() != acc::DataClause::acc_present &&
1093 getDataClause() != acc::DataClause::acc_no_create &&
1094 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1095 getDataClause() != acc::DataClause::acc_declare_link)
1096 return emitError(
1097 "data clause associated with delete operation must match its intent"
1098 " or specify original clause this operation was decomposed from");
1099 if (!getAccVar())
1100 return emitError("must have device pointer");
1101 // This op is the exit part of copyin and create - thus allow all modifiers
1102 // allowed on either case.
1103 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1104 acc::DataClauseModifier::readonly |
1105 acc::DataClauseModifier::always |
1106 acc::DataClauseModifier::capture)))
1107 return failure();
1108 return success();
1109}
1110
1111//===----------------------------------------------------------------------===//
1112// DetachOp
1113//===----------------------------------------------------------------------===//
1114LogicalResult acc::DetachOp::verify() {
1115 // Test for all clauses this operation can be decomposed from:
1116 if (getDataClause() != acc::DataClause::acc_detach &&
1117 getDataClause() != acc::DataClause::acc_attach)
1118 return emitError(
1119 "data clause associated with detach operation must match its intent"
1120 " or specify original clause this operation was decomposed from");
1121 if (!getAccVar())
1122 return emitError("must have device pointer");
1123 if (failed(checkNoModifier(*this)))
1124 return failure();
1125 return success();
1126}
1127
1128//===----------------------------------------------------------------------===//
1129// HostOp
1130//===----------------------------------------------------------------------===//
1131LogicalResult acc::UpdateHostOp::verify() {
1132 // Test for all clauses this operation can be decomposed from:
1133 if (getDataClause() != acc::DataClause::acc_update_host &&
1134 getDataClause() != acc::DataClause::acc_update_self)
1135 return emitError(
1136 "data clause associated with host operation must match its intent"
1137 " or specify original clause this operation was decomposed from");
1138 if (!getVar() || !getAccVar())
1139 return emitError("must have both host and device pointers");
1140 if (failed(checkVarAndVarType(*this)))
1141 return failure();
1142 if (failed(checkVarAndAccVar(*this)))
1143 return failure();
1144 if (failed(checkNoModifier(*this)))
1145 return failure();
1146 return success();
1147}
1148
1149//===----------------------------------------------------------------------===//
1150// DeviceOp
1151//===----------------------------------------------------------------------===//
1152LogicalResult acc::UpdateDeviceOp::verify() {
1153 // Test for all clauses this operation can be decomposed from:
1154 if (getDataClause() != acc::DataClause::acc_update_device)
1155 return emitError(
1156 "data clause associated with device operation must match its intent"
1157 " or specify original clause this operation was decomposed from");
1158 if (failed(checkVarAndVarType(*this)))
1159 return failure();
1160 if (failed(checkVarAndAccVar(*this)))
1161 return failure();
1162 if (failed(checkNoModifier(*this)))
1163 return failure();
1164 return success();
1165}
1166
1167//===----------------------------------------------------------------------===//
1168// UseDeviceOp
1169//===----------------------------------------------------------------------===//
1170LogicalResult acc::UseDeviceOp::verify() {
1171 // Test for all clauses this operation can be decomposed from:
1172 if (getDataClause() != acc::DataClause::acc_use_device)
1173 return emitError(
1174 "data clause associated with use_device operation must match its intent"
1175 " or specify original clause this operation was decomposed from");
1176 if (failed(checkVarAndVarType(*this)))
1177 return failure();
1178 if (failed(checkVarAndAccVar(*this)))
1179 return failure();
1180 if (failed(checkNoModifier(*this)))
1181 return failure();
1182 return success();
1183}
1184
1185//===----------------------------------------------------------------------===//
1186// CacheOp
1187//===----------------------------------------------------------------------===//
1188LogicalResult acc::CacheOp::verify() {
1189 // Test for all clauses this operation can be decomposed from:
1190 if (getDataClause() != acc::DataClause::acc_cache &&
1191 getDataClause() != acc::DataClause::acc_cache_readonly)
1192 return emitError(
1193 "data clause associated with cache operation must match its intent"
1194 " or specify original clause this operation was decomposed from");
1195 if (failed(checkVarAndVarType(*this)))
1196 return failure();
1197 if (failed(checkVarAndAccVar(*this)))
1198 return failure();
1199 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1200 return failure();
1201 return success();
1202}
1203
1204bool acc::CacheOp::isCacheReadonly() {
1205 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1206 acc::bitEnumContainsAny(getModifiers(),
1207 acc::DataClauseModifier::readonly);
1208}
1209
1210//===----------------------------------------------------------------------===//
1211// Data entry/exit operations - getEffects implementations
1212//===----------------------------------------------------------------------===//
1213
1214// This function returns true iff the given operation is enclosed
1215// in any ACC_COMPUTE_CONSTRUCT_OPS operation.
1216// It is very similar to the acc::getEnclosingComputeOp() utility,
1217// but that utility cannot be used here.
1221
1222/// Helper to add an effect on an operand, referenced by its mutable range.
/// One `EffectTy` instance is recorded per operand in `operand`, each tied to
/// that specific OpOperand; an empty range records nothing.
1223template <typename EffectTy>
1226 &effects,
1227 MutableOperandRange operand) {
1228 for (unsigned i = 0, e = operand.size(); i < e; ++i)
1229 effects.emplace_back(EffectTy::get(), &operand[i]);
1230}
1231
1232/// Helper to add an effect on a result value.
/// `result` must be an OpResult: mlir::cast asserts (in builds with
/// assertions enabled) if any other kind of value, e.g. a block argument,
/// is passed.
1233template <typename EffectTy>
1236 &effects,
1237 Value result) {
1238 effects.emplace_back(EffectTy::get(), mlir::cast<mlir::OpResult>(result));
1239}
1240
// NOTE(review): the summary below mentions only the accVar result write, but
// a Read effect is also reported when the op is not inside a compute
// construct (the guarded emplace_back in the body).
1241// PrivateOp: accVar result write.
1242void acc::PrivateOp::getEffects(
1244 &effects) {
1245 // If acc.private is enclosed into a compute operation,
1246 // then it denotes the device side privatization, hence
1247 // it does not access the CurrentDeviceIdResource.
1248 if (!isEnclosedIntoComputeOp(getOperation()))
1249 effects.emplace_back(MemoryEffects::Read::get(),
1251 // TODO: should this be MemoryEffects::Allocate?
1253}
1254
// NOTE(review): besides the effects in the summary below, a Read effect is
// reported when the op is not inside a compute construct (guarded
// emplace_back in the body).
1255// FirstprivateOp: var read, accVar result write.
1256void acc::FirstprivateOp::getEffects(
1258 &effects) {
1259 // If acc.firstprivate is enclosed into a compute operation,
1260 // then it denotes the device side privatization, hence
1261 // it does not access the CurrentDeviceIdResource.
1262 if (!isEnclosedIntoComputeOp(getOperation()))
1263 effects.emplace_back(MemoryEffects::Read::get(),
 // The original (host) value is read to initialize the private copy.
1265 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1267}
1268
// NOTE(review): besides the effects in the summary below, a Read effect is
// reported when the op is not inside a compute construct (guarded
// emplace_back in the body).
1269// ReductionOp: var read, accVar result write.
1270void acc::ReductionOp::getEffects(
1272 &effects) {
1273 // If acc.reduction is enclosed into a compute operation,
1274 // then it denotes the device side reduction, hence
1275 // it does not access the CurrentDeviceIdResource.
1276 if (!isEnclosedIntoComputeOp(getOperation()))
1277 effects.emplace_back(MemoryEffects::Read::get(),
 // Every operand in the var range is recorded as read.
1279 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1281}
1282
// NOTE(review): two Read effects are emitted here; the summary below names
// only the RuntimeCounters one. The second resource is presumably the current
// device id — confirm and extend the summary.
1283// DevicePtrOp: RuntimeCounters read.
1284void acc::DevicePtrOp::getEffects(
1286 &effects) {
1287 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1288 effects.emplace_back(MemoryEffects::Read::get(),
1290}
1291
// NOTE(review): in addition to the RuntimeCounters read+write in the summary
// below, a further Read effect is emitted (third emplace_back) — confirm its
// resource and extend the summary.
1292// PresentOp: RuntimeCounters read+write.
1293void acc::PresentOp::getEffects(
1295 &effects) {
1296 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1297 effects.emplace_back(MemoryEffects::Write::get(),
1299 effects.emplace_back(MemoryEffects::Read::get(),
1301}
1302
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1303// CopyinOp: RuntimeCounters read+write, var read, accVar result write.
1304void acc::CopyinOp::getEffects(
1306 &effects) {
1307 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1308 effects.emplace_back(MemoryEffects::Write::get(),
1310 effects.emplace_back(MemoryEffects::Read::get(),
 // The host value being copied to the device is read.
1312 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1314}
1315
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource. No effect on
// the var operand is recorded here.
1316// CreateOp: RuntimeCounters read+write, accVar result write.
1317void acc::CreateOp::getEffects(
1319 &effects) {
1320 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1321 effects.emplace_back(MemoryEffects::Write::get(),
1323 effects.emplace_back(MemoryEffects::Read::get(),
1325 // TODO: should this be MemoryEffects::Allocate?
1327}
1328
// NOTE(review): a third (Read) effect is emitted beyond the RuntimeCounters
// read+write named in the summary below — confirm its resource.
1329// NoCreateOp: RuntimeCounters read+write.
1330void acc::NoCreateOp::getEffects(
1332 &effects) {
1333 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1334 effects.emplace_back(MemoryEffects::Write::get(),
1336 effects.emplace_back(MemoryEffects::Read::get(),
1338}
1339
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1340// AttachOp: RuntimeCounters read+write, var read.
1341void acc::AttachOp::getEffects(
1343 &effects) {
1344 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1345 effects.emplace_back(MemoryEffects::Write::get(),
1347 effects.emplace_back(MemoryEffects::Read::get(),
1349 // TODO: should we also add MemoryEffects::Write?
1350 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1351}
1352
// NOTE(review): two Read effects are emitted; the summary below names only
// the RuntimeCounters one — confirm the second resource.
1353// GetDevicePtrOp: RuntimeCounters read.
1354void acc::GetDevicePtrOp::getEffects(
1356 &effects) {
1357 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1358 effects.emplace_back(MemoryEffects::Read::get(),
1360}
1361
// NOTE(review): the first emplace_back records a Read effect on a resource
// that the summary below does not mention — confirm and extend the summary.
1362// UpdateDeviceOp: var read, accVar result write.
1363void acc::UpdateDeviceOp::getEffects(
1365 &effects) {
1366 effects.emplace_back(MemoryEffects::Read::get(),
 // The host value pushed to the device is read.
1368 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1370}
1371
// NOTE(review): two Read effects are emitted; the summary below names only
// the RuntimeCounters one — confirm the second resource.
1372// UseDeviceOp: RuntimeCounters read.
1373void acc::UseDeviceOp::getEffects(
1375 &effects) {
1376 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1377 effects.emplace_back(MemoryEffects::Read::get(),
1379}
1380
// NOTE(review): besides the Write and var Read in the summary below, a
// second Read effect is emitted (second emplace_back) — confirm its resource.
1381// DeclareDeviceResidentOp: RuntimeCounters write, var read.
1382void acc::DeclareDeviceResidentOp::getEffects(
1384 &effects) {
1385 effects.emplace_back(MemoryEffects::Write::get(),
1387 effects.emplace_back(MemoryEffects::Read::get(),
1389 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1390}
1391
// NOTE(review): besides the Write and var Read in the summary below, a
// second Read effect is emitted (second emplace_back) — confirm its resource.
1392// DeclareLinkOp: RuntimeCounters write, var read.
1393void acc::DeclareLinkOp::getEffects(
1395 &effects) {
1396 effects.emplace_back(MemoryEffects::Write::get(),
1398 effects.emplace_back(MemoryEffects::Read::get(),
1400 addOperandEffect<MemoryEffects::Read>(effects, getVarMutable());
1401}
1402
// acc.cache reports no memory effects: the body leaves `effects` untouched.
1403// CacheOp: NoMemoryEffect
1404void acc::CacheOp::getEffects(
1406 &effects) {}
1407
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1408// CopyoutOp: RuntimeCounters read+write, accVar read, var write.
1409void acc::CopyoutOp::getEffects(
1411 &effects) {
1412 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1413 effects.emplace_back(MemoryEffects::Write::get(),
1415 effects.emplace_back(MemoryEffects::Read::get(),
 // Data flows device -> host: the device copy is read, the host var written.
1417 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1418 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1419}
1420
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1421// DeleteOp: RuntimeCounters read+write, accVar read.
1422void acc::DeleteOp::getEffects(
1424 &effects) {
1425 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1426 effects.emplace_back(MemoryEffects::Write::get(),
1428 effects.emplace_back(MemoryEffects::Read::get(),
1430 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1431}
1432
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1433// DetachOp: RuntimeCounters read+write, accVar read.
1434void acc::DetachOp::getEffects(
1436 &effects) {
1437 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1438 effects.emplace_back(MemoryEffects::Write::get(),
1440 effects.emplace_back(MemoryEffects::Read::get(),
1442 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1443}
1444
// NOTE(review): the body also emits an extra Read effect (third emplace_back)
// that the summary below does not list — confirm its resource.
1445// UpdateHostOp: RuntimeCounters read+write, accVar read, var write.
1446void acc::UpdateHostOp::getEffects(
1448 &effects) {
1449 effects.emplace_back(MemoryEffects::Read::get(), acc::RuntimeCounters::get());
1450 effects.emplace_back(MemoryEffects::Write::get(),
1452 effects.emplace_back(MemoryEffects::Read::get(),
 // Data flows device -> host: the device copy is read, the host var written.
1454 addOperandEffect<MemoryEffects::Read>(effects, getAccVarMutable());
1455 addOperandEffect<MemoryEffects::Write>(effects, getVarMutable());
1456}
1457
/// Parses `nRegions` consecutive regions (one by default) into `state`. All
/// regions are added to `state` first and then parsed in order, each without
/// explicit entry-block arguments. Returns failure as soon as one region
/// fails to parse.
/// NOTE(review): `StructureOp` is not referenced in the visible body —
/// confirm whether the template parameter is still needed.
1458template <typename StructureOp>
1459static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1460 unsigned nRegions = 1) {
1461
1463 for (unsigned i = 0; i < nRegions; ++i)
1464 regions.push_back(state.addRegion());
1465
1466 for (Region *region : regions)
1467 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1468 return failure();
1469
1470 return success();
1471}
1472
1473namespace {
1474/// Pattern to remove operation without region that have constant false `ifCond`
1475/// and remove the condition from the operation if the `ifCond` is a true
1476/// constant.
1477template <typename OpTy>
1478struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1479 using OpRewritePattern<OpTy>::OpRewritePattern;
1480
1481 LogicalResult matchAndRewrite(OpTy op,
1482 PatternRewriter &rewriter) const override {
1483 // Early return if there is no condition.
1484 Value ifCond = op.getIfCond();
1485 if (!ifCond)
1486 return failure();
1487
1488 IntegerAttr constAttr;
1489 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1490 return failure();
1491 if (constAttr.getInt())
1492 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1493 else
1494 rewriter.eraseOp(op);
1495
1496 return success();
1497 }
1498};
1499
1500/// Replaces the given op with the contents of the given single-block region,
1501/// using the operands of the block terminator to replace operation results.
1502static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1503 Region &region, ValueRange blockArgs = {}) {
1504 assert(region.hasOneBlock() && "expected single-block region");
1505 Block *block = &region.front();
1506 Operation *terminator = block->getTerminator();
1507 ValueRange results = terminator->getOperands();
1508 rewriter.inlineBlockBefore(block, op, blockArgs);
1509 rewriter.replaceOp(op, results);
1510 rewriter.eraseOp(terminator);
1511}
1512
1513/// Pattern to remove operation with region that have constant false `ifCond`
1514/// and remove the condition from the operation if the `ifCond` is constant
1515/// true.
1516template <typename OpTy>
1517struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1518 using OpRewritePattern<OpTy>::OpRewritePattern;
1519
1520 LogicalResult matchAndRewrite(OpTy op,
1521 PatternRewriter &rewriter) const override {
1522 // Early return if there is no condition.
1523 Value ifCond = op.getIfCond();
1524 if (!ifCond)
1525 return failure();
1526
1527 IntegerAttr constAttr;
1528 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1529 return failure();
1530 if (constAttr.getInt())
1531 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1532 else
1533 replaceOpWithRegion(rewriter, op, op.getRegion());
1534
1535 return success();
1536 }
1537};
1538
1539//===----------------------------------------------------------------------===//
1540// Recipe Region Helpers
1541//===----------------------------------------------------------------------===//
1542
1543/// Create and populate an init region for privatization recipes.
1544/// Returns success if the region is populated, failure otherwise.
1545/// Sets needsFree to indicate if the allocated memory requires deallocation.
1546static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1547 Region &initRegion, Type varType,
1548 StringRef varName, ValueRange bounds,
1549 bool &needsFree) {
1550 // Create init block with arguments: original value + bounds
1551 SmallVector<Type> argTypes{varType};
1552 SmallVector<Location> argLocs{loc};
1553 for (Value bound : bounds) {
1554 argTypes.push_back(bound.getType());
1555 argLocs.push_back(loc);
1556 }
1557
1558 Block *initBlock = builder.createBlock(&initRegion);
1559 initBlock->addArguments(argTypes, argLocs);
1560 builder.setInsertionPointToStart(initBlock);
1561
1562 Value privatizedValue;
1563
1564 // Get the block argument that represents the original variable
1565 Value blockArgVar = initBlock->getArgument(0);
1566
1567 // Generate init region body based on variable type
1568 if (isa<MappableType>(varType)) {
1569 auto mappableTy = cast<MappableType>(varType);
1570 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1571 privatizedValue = mappableTy.generatePrivateInit(
1572 builder, loc, typedVar, varName, bounds, {}, needsFree);
1573 if (!privatizedValue)
1574 return failure();
1575 } else {
1576 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1577 auto pointerLikeTy = cast<PointerLikeType>(varType);
1578 // Use PointerLikeType's allocation API with the block argument
1579 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1580 blockArgVar, needsFree);
1581 if (!privatizedValue)
1582 return failure();
1583 }
1584
1585 // Add yield operation to init block
1586 acc::YieldOp::create(builder, loc, privatizedValue);
1587
1588 return success();
1589}
1590
1591/// Create and populate a copy region for firstprivate recipes.
1592/// Returns success if the region is populated, failure otherwise.
1593/// TODO: Handle MappableType - it does not yet have a copy API.
1594static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1595 Region &copyRegion, Type varType,
1596 ValueRange bounds) {
1597 // Create copy block with arguments: original value + privatized value +
1598 // bounds
1599 SmallVector<Type> copyArgTypes{varType, varType};
1600 SmallVector<Location> copyArgLocs{loc, loc};
1601 for (Value bound : bounds) {
1602 copyArgTypes.push_back(bound.getType());
1603 copyArgLocs.push_back(loc);
1604 }
1605
1606 Block *copyBlock = builder.createBlock(&copyRegion);
1607 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1608 builder.setInsertionPointToStart(copyBlock);
1609
1610 bool isMappable = isa<MappableType>(varType);
1611 bool isPointerLike = isa<PointerLikeType>(varType);
1612 // TODO: Handle MappableType - it does not yet have a copy API.
1613 // Otherwise, for now just fallback to pointer-like behavior.
1614 if (isMappable && !isPointerLike)
1615 return failure();
1616
1617 // Generate copy region body based on variable type
1618 if (isPointerLike) {
1619 auto pointerLikeTy = cast<PointerLikeType>(varType);
1620 Value originalArg = copyBlock->getArgument(0);
1621 Value privatizedArg = copyBlock->getArgument(1);
1622
1623 // Generate copy operation using PointerLikeType interface
1624 if (!pointerLikeTy.genCopy(
1625 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1626 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1627 return failure();
1628 }
1629
1630 // Add terminator to copy block
1631 acc::TerminatorOp::create(builder, loc);
1632
1633 return success();
1634}
1635
1636/// Create and populate a destroy region for privatization recipes.
1637/// Returns success if the region is populated, failure otherwise.
1638static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1639 Region &destroyRegion, Type varType,
1640 Value allocRes, ValueRange bounds) {
1641 // Create destroy block with arguments: original value + privatized value +
1642 // bounds
1643 SmallVector<Type> destroyArgTypes{varType, varType};
1644 SmallVector<Location> destroyArgLocs{loc, loc};
1645 for (Value bound : bounds) {
1646 destroyArgTypes.push_back(bound.getType());
1647 destroyArgLocs.push_back(loc);
1648 }
1649
1650 Block *destroyBlock = builder.createBlock(&destroyRegion);
1651 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1652 builder.setInsertionPointToStart(destroyBlock);
1653
1654 auto varToFree =
1655 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1656 if (isa<MappableType>(varType)) {
1657 auto mappableTy = cast<MappableType>(varType);
1658 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree, bounds))
1659 return failure();
1660 } else {
1661 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1662 auto pointerLikeTy = cast<PointerLikeType>(varType);
1663 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1664 return failure();
1665 }
1666
1667 acc::TerminatorOp::create(builder, loc);
1668 return success();
1669}
1670
1671} // namespace
1672
1673//===----------------------------------------------------------------------===//
1674// PrivateRecipeOp
1675//===----------------------------------------------------------------------===//
1676
1678 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1679 Type type, bool verifyYield, bool optional = false) {
 // Checks that `region` (called `regionName` in diagnostics) has an entry
 // block whose first argument is `type` (described as the `regionType` type
 // in messages), and optionally that its acc.yield ops produce that type.
 // An empty region is accepted only when the caller marked it optional.
1680 if (optional && region.empty())
1681 return success();
1682
1683 if (region.empty())
1684 return op->emitOpError() << "expects non-empty " << regionName << " region";
 // Only the entry block's first argument is checked; any further
 // arguments are not verified here.
1685 Block &firstBlock = region.front();
1686 if (firstBlock.getNumArguments() < 1 ||
1687 firstBlock.getArgument(0).getType() != type)
1688 return op->emitOpError() << "expects " << regionName
1689 << " region first "
1690 "argument of the "
1691 << regionType << " type";
1692
 // When requested, each acc.yield directly in the region's blocks must
 // yield exactly one value of `type`. NOTE(review): getOps does not look
 // into nested regions — confirm that is intended.
1693 if (verifyYield) {
1694 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1695 if (yieldOp.getOperands().size() != 1 ||
1696 yieldOp.getOperands().getTypes()[0] != type)
1697 return op->emitOpError() << "expects " << regionName
1698 << " region to "
1699 "yield a value of the "
1700 << regionType << " type";
1701 }
1702 }
1703 return success();
1704}
1705
/// Verifies the private recipe regions. The init region is mandatory and must
/// take the privatization type as its first argument; its yields are not
/// checked (verifyYield=false). The destroy region is optional
/// (optional=true) and, when present, has the same first-argument
/// requirement.
1706LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1707 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1708 "privatization", "init", getType(),
1709 /*verifyYield=*/false)))
1710 return failure();
1712 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1713 /*verifyYield=*/false, /*optional=*/true)))
1714 return failure();
1715 return success();
1716}
1717
1718std::optional<PrivateRecipeOp>
1719PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1720 StringRef recipeName, Type varType,
1721 StringRef varName, ValueRange bounds) {
1722 // First, validate that we can handle this variable type
1723 bool isMappable = isa<MappableType>(varType);
1724 bool isPointerLike = isa<PointerLikeType>(varType);
1725
1726 // Unsupported type
1727 if (!isMappable && !isPointerLike)
1728 return std::nullopt;
1729
1730 OpBuilder::InsertionGuard guard(builder);
1731
1732 // Create the recipe operation first so regions have proper parent context
1733 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1734
1735 // Populate the init region
1736 bool needsFree = false;
1737 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1738 varName, bounds, needsFree))) {
1739 recipe.erase();
1740 return std::nullopt;
1741 }
1742
1743 // Only create destroy region if the allocation needs deallocation
1744 if (needsFree) {
1745 // Extract the allocated value from the init block's yield operation
1746 auto yieldOp =
1747 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1748 Value allocRes = yieldOp.getOperand(0);
1749
1750 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1751 varType, allocRes, bounds))) {
1752 recipe.erase();
1753 return std::nullopt;
1754 }
1755 }
1756
1757 return recipe;
1758}
1759
1760std::optional<PrivateRecipeOp>
1761PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1762 StringRef recipeName,
1763 FirstprivateRecipeOp firstprivRecipe) {
1764 // Create the private.recipe op with the same type as the firstprivate.recipe.
1765 OpBuilder::InsertionGuard guard(builder);
1766 auto varType = firstprivRecipe.getType();
1767 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1768
1769 // Clone the init region
1770 IRMapping mapping;
1771 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1772
1773 // Clone destroy region if the firstprivate.recipe has one.
1774 if (!firstprivRecipe.getDestroyRegion().empty()) {
1775 IRMapping mapping;
1776 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1777 mapping);
1778 }
1779 return recipe;
1780}
1781
1782//===----------------------------------------------------------------------===//
1783// FirstprivateRecipeOp
1784//===----------------------------------------------------------------------===//
1785
1786LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1787 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1788 "privatization", "init", getType(),
1789 /*verifyYield=*/false)))
1790 return failure();
1791
1792 if (getCopyRegion().empty())
1793 return emitOpError() << "expects non-empty copy region";
1794
1795 Block &firstBlock = getCopyRegion().front();
1796 if (firstBlock.getNumArguments() < 2 ||
1797 firstBlock.getArgument(0).getType() != getType())
1798 return emitOpError() << "expects copy region with two arguments of the "
1799 "privatization type";
1800
1801 if (getDestroyRegion().empty())
1802 return success();
1803
1804 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1805 "privatization", "destroy",
1806 getType(), /*verifyYield=*/false)))
1807 return failure();
1808
1809 return success();
1810}
1811
1812std::optional<FirstprivateRecipeOp>
1813FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1814 StringRef recipeName, Type varType,
1815 StringRef varName, ValueRange bounds) {
1816 // First, validate that we can handle this variable type
1817 bool isMappable = isa<MappableType>(varType);
1818 bool isPointerLike = isa<PointerLikeType>(varType);
1819
1820 // Unsupported type
1821 if (!isMappable && !isPointerLike)
1822 return std::nullopt;
1823
1824 OpBuilder::InsertionGuard guard(builder);
1825
1826 // Create the recipe operation first so regions have proper parent context
1827 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1828
1829 // Populate the init region
1830 bool needsFree = false;
1831 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1832 varName, bounds, needsFree))) {
1833 recipe.erase();
1834 return std::nullopt;
1835 }
1836
1837 // Populate the copy region
1838 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1839 bounds))) {
1840 recipe.erase();
1841 return std::nullopt;
1842 }
1843
1844 // Only create destroy region if the allocation needs deallocation
1845 if (needsFree) {
1846 // Extract the allocated value from the init block's yield operation
1847 auto yieldOp =
1848 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1849 Value allocRes = yieldOp.getOperand(0);
1850
1851 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1852 varType, allocRes, bounds))) {
1853 recipe.erase();
1854 return std::nullopt;
1855 }
1856 }
1857
1858 return recipe;
1859}
1860
1861//===----------------------------------------------------------------------===//
1862// ReductionRecipeOp
1863//===----------------------------------------------------------------------===//
1864
1865LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1866 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1867 "init", getType(),
1868 /*verifyYield=*/false)))
1869 return failure();
1870
1871 if (getCombinerRegion().empty())
1872 return emitOpError() << "expects non-empty combiner region";
1873
1874 Block &reductionBlock = getCombinerRegion().front();
1875 if (reductionBlock.getNumArguments() < 2 ||
1876 reductionBlock.getArgument(0).getType() != getType() ||
1877 reductionBlock.getArgument(1).getType() != getType())
1878 return emitOpError() << "expects combiner region with the first two "
1879 << "arguments of the reduction type";
1880
1881 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1882 if (yieldOp.getOperands().size() != 1 ||
1883 yieldOp.getOperands().getTypes()[0] != getType())
1884 return emitOpError() << "expects combiner region to yield a value "
1885 "of the reduction type";
1886 }
1887
1888 return success();
1889}
1890
1891//===----------------------------------------------------------------------===//
1892// ParallelOp
1893//===----------------------------------------------------------------------===//
1894
1895/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1896template <typename Op>
1897static LogicalResult checkDataOperands(Op op,
1898 const mlir::ValueRange &operands) {
1899 for (mlir::Value operand : operands)
1900 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1901 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1902 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1903 operand.getDefiningOp()))
1904 return op.emitError(
1905 "expect data entry/exit operation or acc.getdeviceptr "
1906 "as defining op");
1907 return success();
1908}
1909
1910template <typename OpT, typename RecipeOpT>
1911static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1912 const mlir::ValueRange &operands,
1913 llvm::StringRef operandName) {
1915 for (mlir::Value operand : operands) {
1916 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1917 return accConstructOp->emitOpError()
1918 << "expected " << operandName << " as defining op";
1919 if (!set.insert(operand).second)
1920 return accConstructOp->emitOpError()
1921 << operandName << " operand appears more than once";
1922 }
1923 return success();
1924}
1925
1926unsigned ParallelOp::getNumDataOperands() {
1927 return getReductionOperands().size() + getPrivateOperands().size() +
1928 getFirstprivateOperands().size() + getDataClauseOperands().size();
1929}
1930
1931Value ParallelOp::getDataOperand(unsigned i) {
1932 unsigned numOptional = getAsyncOperands().size();
1933 numOptional += getNumGangs().size();
1934 numOptional += getNumWorkers().size();
1935 numOptional += getVectorLength().size();
1936 numOptional += getIfCond() ? 1 : 0;
1937 numOptional += getSelfCond() ? 1 : 0;
1938 return getOperand(getWaitOperands().size() + numOptional + i);
1939}
1940
1941template <typename Op>
1942static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1943 ArrayAttr deviceTypes,
1944 llvm::StringRef keyword) {
1945 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1946 return op.emitOpError() << keyword << " operands count must match "
1947 << keyword << " device_type count";
1948 return success();
1949}
1950
/// Verify a device_type-segmented clause (used for num_gangs and wait
/// operands).
///
/// `segments` gives the number of operands in each segment, and `deviceTypes`
/// gives one device type per segment. An error is emitted on `op` when:
///  - `maxInSegment` is non-zero and some segment holds more values than that;
///  - the segment sizes do not add up to the number of operands;
///  - operands are present but `deviceTypes` is missing;
///  - `deviceTypes` is present but its size differs from the number of
///    segments.
template <typename Op>
    Op op, OperandRange operands, DenseI32ArrayAttr segments,
    ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
  std::size_t numOperandsInSegments = 0;
  std::size_t nbOfSegments = 0;

  // A missing `segments` attribute counts as zero segments holding zero
  // operands.
  if (segments) {
    for (auto segCount : segments.asArrayRef()) {
      // A maxInSegment of 0 means "no per-segment limit".
      if (maxInSegment != 0 && segCount > maxInSegment)
        return op.emitOpError() << keyword << " expects a maximum of "
                                << maxInSegment << " values per segment";
      numOperandsInSegments += segCount;
      ++nbOfSegments;
    }
  }

  if ((numOperandsInSegments != operands.size()) ||
      (!deviceTypes && !operands.empty()))
    return op.emitOpError()
           << keyword << " operand count does not match count in segments";
  if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
    return op.emitOpError()
           << keyword << " segment count does not match device_type count";
  return success();
}
1977
/// Verify acc.parallel:
///  - private, firstprivate and reduction operands must be defined by
///    acc.private / acc.firstprivate / acc.reduction respectively, and each
///    may appear at most once;
///  - num_gangs operands (at most 3 values per device_type segment) and wait
///    operands must agree with their segment and device_type attributes;
///  - num_workers and vector_length operands must match their device_type
///    counts, and async operands are checked against their device_type
///    attribute;
///  - data clause operands must come from data entry/exit ops or
///    acc.getdeviceptr.
LogicalResult acc::ParallelOp::verify() {
  if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
                                  mlir::acc::PrivateRecipeOp>(
          *this, getPrivateOperands(), "private")))
    return failure();
  if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
                                  mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivateOperands(), "firstprivate")))
    return failure();
  if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
                                  mlir::acc::ReductionRecipeOp>(
          *this, getReductionOperands(), "reduction")))
    return failure();

          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

    return failure();

  return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
}
2022
/// Return the element of `range` at the position where `deviceType` appears
/// in `arrayAttr`, located with findSegment.
///
/// Returns a null Value when `arrayAttr` is absent or findSegment finds no
/// entry for `deviceType`.
static mlir::Value
getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
                            mlir::acc::DeviceType deviceType) {
  if (!arrayAttr)
    return {};
  if (auto pos = findSegment(*arrayAttr, deviceType))
    return range[*pos];
  return {};
}
2033
2034bool acc::ParallelOp::hasAsyncOnly() {
2035 return hasAsyncOnly(mlir::acc::DeviceType::None);
2036}
2037
2038bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2039 return hasDeviceType(getAsyncOnly(), deviceType);
2040}
2041
2042mlir::Value acc::ParallelOp::getAsyncValue() {
2043 return getAsyncValue(mlir::acc::DeviceType::None);
2044}
2045
/// Return the async operand associated with `deviceType`.
/// NOTE(review): presumably a null Value when no async operand is attached to
/// `deviceType` — confirm against the lookup helper used here.
mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
2050
2051mlir::Value acc::ParallelOp::getNumWorkersValue() {
2052 return getNumWorkersValue(mlir::acc::DeviceType::None);
2053}
2054
acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // Return the num_workers operand whose device_type entry matches
  // `deviceType`. getValueInDeviceTypeSegment yields a null Value when there
  // is none.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
2060
2061mlir::Value acc::ParallelOp::getVectorLengthValue() {
2062 return getVectorLengthValue(mlir::acc::DeviceType::None);
2063}
2064
acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // Return the vector_length operand whose device_type entry matches
  // `deviceType`. getValueInDeviceTypeSegment yields a null Value when there
  // is none.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
2070
2071mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
2072 return getNumGangsValues(mlir::acc::DeviceType::None);
2073}
2074
ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // Select the num_gangs operand segment associated with `deviceType`, using
  // the num_gangs device_type and segment-size attributes.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
2080
2081bool acc::ParallelOp::hasWaitOnly() {
2082 return hasWaitOnly(mlir::acc::DeviceType::None);
2083}
2084
2085bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2086 return hasDeviceType(getWaitOnly(), deviceType);
2087}
2088
2089mlir::Operation::operand_range ParallelOp::getWaitValues() {
2090 return getWaitValues(mlir::acc::DeviceType::None);
2091}
2092
ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
  // Wait operands in the segment for `deviceType`. The hasWaitDevnum flags
  // are passed along, presumably so that a leading devnum operand (returned
  // separately by getWaitDevnum) can be excluded. NOTE(review): confirm.
      getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
      getHasWaitDevnum(), deviceType);
}
2099
2100mlir::Value ParallelOp::getWaitDevnum() {
2101 return getWaitDevnum(mlir::acc::DeviceType::None);
2102}
2103
2104mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2105 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2106 getWaitOperandsSegments(), getHasWaitDevnum(),
2107 deviceType);
2108}
2109
2110void ParallelOp::build(mlir::OpBuilder &odsBuilder,
2111 mlir::OperationState &odsState,
2112 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
2113 mlir::ValueRange vectorLength,
2114 mlir::ValueRange asyncOperands,
2115 mlir::ValueRange waitOperands, mlir::Value ifCond,
2116 mlir::Value selfCond, mlir::ValueRange reductionOperands,
2117 mlir::ValueRange gangPrivateOperands,
2118 mlir::ValueRange gangFirstPrivateOperands,
2119 mlir::ValueRange dataClauseOperands) {
2120 ParallelOp::build(
2121 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
2122 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
2123 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
2124 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
2125 /*numGangsDeviceType=*/nullptr, numWorkers,
2126 /*numWorkersDeviceType=*/nullptr, vectorLength,
2127 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
2128 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
2129 gangFirstPrivateOperands, dataClauseOperands,
2130 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
2131}
2132
2133void acc::ParallelOp::addNumWorkersOperand(
2134 MLIRContext *context, mlir::Value newValue,
2135 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2136 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2137 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2138 getNumWorkersMutable()));
2139}
2140void acc::ParallelOp::addVectorLengthOperand(
2141 MLIRContext *context, mlir::Value newValue,
2142 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2143 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2144 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2145 getVectorLengthMutable()));
2146}
2147
2148void acc::ParallelOp::addAsyncOnly(
2149 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2150 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2151 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2152}
2153
2154void acc::ParallelOp::addAsyncOperand(
2155 MLIRContext *context, mlir::Value newValue,
2156 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2157 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2158 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2159 getAsyncOperandsMutable()));
2160}
2161
/// Add `newValues` as num_gangs operands for `effectiveDeviceTypes`.
///
/// The existing segment sizes are copied into a local list. The list is
/// handed to addDeviceTypeAffectedOperandHelper together with the num_gangs
/// operands, so the helper can presumably extend it. The resulting device_type
/// attribute and segment list are then written back.
void acc::ParallelOp::addNumGangsOperands(
    MLIRContext *context, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  if (getNumGangsSegments())
    llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));

  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getNumGangsMutable(), segments));

  setNumGangsSegments(segments);
}
2175void acc::ParallelOp::addWaitOnly(
2176 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2177 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2178 effectiveDeviceTypes));
2179}
/// Add `newValues` as a wait operand group for `effectiveDeviceTypes`.
///
/// Updates, in order:
///  - the wait operands, the wait device_type attribute and the wait segment
///    sizes, via addDeviceTypeAffectedOperandHelper on a copy of the existing
///    segments;
///  - the hasWaitDevnum array, extended with
///    max(effectiveDeviceTypes.size(), 1) copies of `hasDevnum`.
void acc::ParallelOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // One flag per effective device type, or a single flag when the list is
  // empty (presumably for the implicit `none` device type). NOTE(review):
  // confirm this matches the number of segments the helper appended.
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
2202
2203void acc::ParallelOp::addPrivatization(MLIRContext *context,
2204 mlir::acc::PrivateOp op,
2205 mlir::acc::PrivateRecipeOp recipe) {
2206 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2207 getPrivateOperandsMutable().append(op.getResult());
2208}
2209
2210void acc::ParallelOp::addFirstPrivatization(
2211 MLIRContext *context, mlir::acc::FirstprivateOp op,
2212 mlir::acc::FirstprivateRecipeOp recipe) {
2213 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2214 getFirstprivateOperandsMutable().append(op.getResult());
2215}
2216
2217void acc::ParallelOp::addReduction(MLIRContext *context,
2218 mlir::acc::ReductionOp op,
2219 mlir::acc::ReductionRecipeOp recipe) {
2220 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2221 getReductionOperandsMutable().append(op.getResult());
2222}
2223
/// Parse the num_gangs operand list.
///
/// Grammar: one or more comma-separated groups `{%v : type, ...}`, each
/// optionally followed by `[attr]` giving its device type. A group without
/// `[attr]` gets DeviceType::None.
///
/// On success:
///  - `deviceTypes` holds one entry per group;
///  - `segments` holds the number of operands parsed in each group.
static ParseResult parseNumGangs(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    // Operands parsed inside this brace group form one segment.
    int32_t crtOperandsSize = operands.size();
    if (failed(parser.parseCommaSeparatedList(
              if (parser.parseOperand(operands.emplace_back()) ||
                  parser.parseColonType(types.emplace_back()))
                return failure();
              return success();
            })))
      return failure();
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    // Optional `[device_type]` suffix; default to DeviceType::None.
    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
2267
  // Print ` [attr]` unless `attr` is the DeviceType::None device type, which
  // is the parsers' default and is therefore left implicit.
  // NOTE(review): the dyn_cast result is used without a null check. An `attr`
  // that is not a DeviceTypeAttr would be dereferenced as a null attribute.
  auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
  if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
    p << " [" << attr << "]";
}
2273
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Print one `{%v : type, ...}` group per device_type entry. Entry i
  // consumes (*segments)[i] consecutive operands, and each group is followed
  // by printSingleDeviceType. Both `deviceTypes` and `segments` are
  // dereferenced unconditionally, so they are assumed present. Operand types
  // come from the operands themselves; `types` is not read.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
2290
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments) {
  // Same grammar as parseNumGangs: comma-separated `{%v : type, ...}` groups,
  // each with an optional `[attr]` device type (default DeviceType::None).
  // The outputs are one device_type entry and one segment size per group.

  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    if (failed(parser.parseCommaSeparatedList(
              if (parser.parseOperand(operands.emplace_back()) ||
                  parser.parseColonType(types.emplace_back()))
                return failure();
              return success();
            })))
      return failure();

    // Operands parsed inside this brace group form one segment.
    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(attributes.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      attributes.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }
  } while (succeeded(parser.parseOptionalComma()));

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);

  return success();
}
2336
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments) {
  // Print one `{%v : type, ...}` group per device_type entry. Entry i
  // consumes (*segments)[i] consecutive operands and is followed by
  // printSingleDeviceType. `deviceTypes` and `segments` are dereferenced
  // unconditionally, so they are assumed present.
  unsigned opIdx = 0;
  llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
    p << "{";
    llvm::interleaveComma(
        llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
          p << operands[opIdx] << " : " << operands[opIdx].getType();
          ++opIdx;
        });
    p << "}";
    printSingleDeviceType(p, it.value());
  });
}
2353
/// Parse the wait clause.
///
/// Accepted forms:
///  - no parenthesized list: keyword-only `wait`, recorded as a single
///    DeviceType::None entry in `keywordOnly`;
///  - `(` [ `[attr, ...]` `,` ] group (`,` group)* `)`, where a group is
///    `{ [devnum:] %v : type, ... }` optionally followed by `[attr]`.
///
/// Outputs, one entry per group:
///  - `deviceTypes`: the group's device type (default DeviceType::None);
///  - `segments`: the group's operand count;
///  - `hasDevNum`: whether the group began with `devnum:`.
/// `keywordOnly` receives the bracketed keyword-only device types.
///
/// NOTE(review): once a keyword-only `[...]` list is parsed, a `,` and at
/// least one `{...}` group are required. printWaitClause emits no `, `
/// separator when there are no operand groups. Confirm such IR round-trips.
static ParseResult parseWaitClause(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
    mlir::ArrayAttr &keywordOnly) {
  llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;

  bool needCommaBeforeOperands = false;

  // Keyword only: bare `wait` applies to the default (None) device type.
  if (failed(parser.parseOptionalLParen())) {
    keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(keywordAttrs.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  // Each iteration parses one `{ [devnum:] operands }` group and its optional
  // `[device_type]` suffix.
  do {
    if (failed(parser.parseLBrace()))
      return failure();

    int32_t crtOperandsSize = operands.size();

    if (succeeded(parser.parseOptionalKeyword("devnum"))) {
      if (failed(parser.parseColon()))
        return failure();
      devnum.push_back(BoolAttr::get(parser.getContext(), true));
    } else {
      devnum.push_back(BoolAttr::get(parser.getContext(), false));
    }

    if (failed(parser.parseCommaSeparatedList(
              if (parser.parseOperand(operands.emplace_back()) ||
                  parser.parseColonType(types.emplace_back()))
                return failure();
              return success();
            })))
      return failure();

    seg.push_back(operands.size() - crtOperandsSize);

    if (failed(parser.parseRBrace()))
      return failure();

    if (succeeded(parser.parseOptionalLSquare())) {
      if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
          parser.parseRSquare())
        return failure();
    } else {
      deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
          parser.getContext(), mlir::acc::DeviceType::None));
    }
  } while (succeeded(parser.parseOptionalComma()));

  if (failed(parser.parseRParen()))
    return failure();

  deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
  keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
  segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
  hasDevNum = ArrayAttr::get(parser.getContext(), devnum);

  return success();
}
2437
2438static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2439 if (!hasDeviceTypeValues(attrs))
2440 return false;
2441 if (attrs->size() != 1)
2442 return false;
2443 if (auto deviceTypeAttr =
2444 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2445 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2446 return false;
2447}
2448
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> hasDevNum,
    std::optional<mlir::ArrayAttr> keywordOnly) {
  // Print nothing for keyword-only `wait` on the default device type.
  // Otherwise print `(` followed by:
  //  - the keyword-only device types;
  //  - `, ` when both parts are present;
  //  - one `{ [devnum: ] %v : type, ... }` group per device_type entry, each
  //    followed by printSingleDeviceType;
  // and finally `)`. When there are operand groups, `hasDevNum` and
  // `segments` are dereferenced unconditionally.

  if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
    return;

  p << "(";

  printDeviceTypes(p, keywordOnly);
  if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
    p << ", ";

  if (hasDeviceTypeValues(deviceTypes)) {
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
      p << "{";
      auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
      if (boolAttr && boolAttr.getValue())
        p << "devnum: ";
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
            p << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, it.value());
    });
  }

  p << ")";
}
2484
/// Parse a comma-separated list of `%v : type` entries. Each entry may be
/// followed by `[attr]` giving its device type (default DeviceType::None).
/// `deviceTypes` receives one entry per operand.
static ParseResult parseDeviceTypeOperands(
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();
  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
2510
/// Print `%v : type [device_type]` entries, pairing each device_type entry
/// with the operand at the same position. Prints nothing when `deviceTypes`
/// has no values. llvm::zip stops at the shorter of the two ranges.
static void
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> deviceTypes) {
  if (!hasDeviceTypeValues(deviceTypes))
    return;
  llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
    p << std::get<1>(it) << " : " << std::get<1>(it).getType();
    printSingleDeviceType(p, std::get<0>(it));
  });
}
2522
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
    mlir::ArrayAttr &keywordOnlyDeviceType) {
  // Parse a clause (e.g. async) that may appear as a bare keyword or with a
  // parenthesized list:
  //  - bare keyword: `keywordOnlyDeviceType` = [DeviceType::None];
  //  - `(` [ `[attr, ...]` `,` ] `%v : type [attr]`, ... `)`: `deviceTypes`
  //    receives one entry per operand (default DeviceType::None).
  // NOTE(review):
  //  - on the parenthesized path, `keywordOnlyDeviceType` is never assigned,
  //    so any `[attr, ...]` keyword-only list parsed there is discarded.
  //    Compare parseWaitClause, which stores its keyword-only list.
  //  - after such a list, a `,` and at least one operand are required.
  // Confirm both are intended.

  llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
  bool needCommaBeforeOperands = false;

  if (failed(parser.parseOptionalLParen())) {
    // Keyword only
    keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
        parser.getContext(), mlir::acc::DeviceType::None));
    keywordOnlyDeviceType =
        ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
    return success();
  }

  // Parse keyword only attributes
  if (succeeded(parser.parseOptionalLSquare())) {
    // Parse keyword only attributes
    if (failed(parser.parseCommaSeparatedList([&]() {
          if (parser.parseAttribute(
                  keywordOnlyDeviceTypeAttributes.emplace_back()))
            return failure();
          return success();
        })))
      return failure();
    if (parser.parseRSquare())
      return failure();
    needCommaBeforeOperands = true;
  }

  if (needCommaBeforeOperands && failed(parser.parseComma()))
    return failure();

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()) ||
            parser.parseColonType(types.emplace_back()))
          return failure();
        if (succeeded(parser.parseOptionalLSquare())) {
          if (parser.parseAttribute(attributes.emplace_back()) ||
              parser.parseRSquare())
            return failure();
        } else {
          attributes.push_back(mlir::acc::DeviceTypeAttr::get(
              parser.getContext(), mlir::acc::DeviceType::None));
        }
        return success();
      })))
    return failure();

  if (failed(parser.parseRParen()))
    return failure();

  llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
                                               attributes.end());
  deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
  return success();
}
2584
    mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
  // Print nothing for a bare keyword on the default device type. Otherwise
  // print `(` followed by:
  //  - the keyword-only device types;
  //  - `, ` when both parts are present;
  //  - the operands via printDeviceTypeOperands;
  // and finally `)`.

  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
    return;
  }

  p << "(";
  printDeviceTypes(p, keywordOnlyDeviceTypes);
  if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";
  printDeviceTypeOperands(p, op, operands, types, deviceTypes);
  p << ")";
}
2603
    mlir::OpAsmParser &parser,
    std::optional<OpAsmParser::UnresolvedOperand> &operand,
    mlir::Type &operandType, mlir::UnitAttr &attr) {
  // Accept either a bare keyword, which sets `attr` to a UnitAttr and leaves
  // `operand` unset, or `(%v : type)`, which fills `operand` and
  // `operandType`.
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  if (failed(parser.parseOperand(op)))
    return failure();
  operand = op;
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseType(operandType)))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  return success();
}
2627
    mlir::Operation *op,
    std::optional<mlir::Value> operand,
    mlir::Type operandType,
    mlir::UnitAttr attr) {
  // Bare keyword form (attr set): print nothing. Otherwise print
  // `(%v : type)`. `operand` is dereferenced, so it must be present whenever
  // `attr` is null.
  if (attr)
    return;

  p << "(";
  p.printOperand(*operand);
  p << " : ";
  p.printType(operandType);
  p << ")";
}
2642
    mlir::OpAsmParser &parser,
    llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
  // Accept either a bare keyword, which sets `attr` to a UnitAttr, or
  // `(%a, %b, ... : t1, t2, ...)`. Operand and type counts are not compared
  // here; presumably the generated parser checks them when resolving
  // operands.
  // Keyword only
  if (failed(parser.parseOptionalLParen())) {
    attr = mlir::UnitAttr::get(parser.getContext());
    return success();
  }

  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseOperand(operands.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  if (failed(parser.parseColon()))
    return failure();
  if (failed(parser.parseCommaSeparatedList([&]() {
        if (parser.parseType(types.emplace_back()))
          return failure();
        return success();
      })))
    return failure();
  if (failed(parser.parseRParen()))
    return failure();

  return success();
}
2672
    mlir::Operation *op,
    mlir::OperandRange operands,
    mlir::TypeRange types,
    mlir::UnitAttr attr) {
  // Bare keyword form (attr set): print nothing. Otherwise print
  // `(%a, %b, ... : t1, t2, ...)`.
  if (attr)
    return;

  p << "(";
  llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
  p << " : ";
  llvm::interleaveComma(types, p, [&](auto it) { p << it; });
  p << ")";
}
2687
/// Parse the compute-construct keyword of a combined construct. The keywords
/// `kernels`, `parallel` and `serial` map to KernelsLoop, ParallelLoop and
/// SerialLoop. Anything else emits "expected compute construct name" and
/// fails.
static ParseResult
                            mlir::acc::CombinedConstructsTypeAttr &attr) {
  if (succeeded(parser.parseOptionalKeyword("kernels"))) {
    attr = mlir::acc::CombinedConstructsTypeAttr::get(
        parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
  } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
    attr = mlir::acc::CombinedConstructsTypeAttr::get(
        parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
  } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
    attr = mlir::acc::CombinedConstructsTypeAttr::get(
        parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
  } else {
    parser.emitError(parser.getCurrentLocation(),
                     "expected compute construct name");
    return failure();
  }
  return success();
}
2707
/// Print the compute-construct keyword (`kernels`, `parallel` or `serial`)
/// for a combined construct. Prints nothing when `attr` is null. This is the
/// inverse of parseCombinedConstructsLoop.
static void
                            mlir::acc::CombinedConstructsTypeAttr attr) {
  if (attr) {
    switch (attr.getValue()) {
    case mlir::acc::CombinedConstructsType::KernelsLoop:
      p << "kernels";
      break;
    case mlir::acc::CombinedConstructsType::ParallelLoop:
      p << "parallel";
      break;
    case mlir::acc::CombinedConstructsType::SerialLoop:
      p << "serial";
      break;
    };
  }
}
2725
2726//===----------------------------------------------------------------------===//
2727// SerialOp
2728//===----------------------------------------------------------------------===//
2729
2730unsigned SerialOp::getNumDataOperands() {
2731 return getReductionOperands().size() + getPrivateOperands().size() +
2732 getFirstprivateOperands().size() + getDataClauseOperands().size();
2733}
2734
2735Value SerialOp::getDataOperand(unsigned i) {
2736 unsigned numOptional = getAsyncOperands().size();
2737 numOptional += getIfCond() ? 1 : 0;
2738 numOptional += getSelfCond() ? 1 : 0;
2739 return getOperand(getWaitOperands().size() + numOptional + i);
2740}
2741
2742bool acc::SerialOp::hasAsyncOnly() {
2743 return hasAsyncOnly(mlir::acc::DeviceType::None);
2744}
2745
2746bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2747 return hasDeviceType(getAsyncOnly(), deviceType);
2748}
2749
2750mlir::Value acc::SerialOp::getAsyncValue() {
2751 return getAsyncValue(mlir::acc::DeviceType::None);
2752}
2753
/// Return the async operand associated with `deviceType`.
/// NOTE(review): presumably a null Value when no async operand is attached to
/// `deviceType` — confirm against the lookup helper used here.
mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
                                     getAsyncOperands(), deviceType);
}
2758
2759bool acc::SerialOp::hasWaitOnly() {
2760 return hasWaitOnly(mlir::acc::DeviceType::None);
2761}
2762
2763bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2764 return hasDeviceType(getWaitOnly(), deviceType);
2765}
2766
2767mlir::Operation::operand_range SerialOp::getWaitValues() {
2768 return getWaitValues(mlir::acc::DeviceType::None);
2769}
2770
/// Returns the wait values for `deviceType`. The selection uses the wait
/// operands together with their device_type, segment and hasWaitDevnum
/// attributes.
SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
    getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
    getHasWaitDevnum(), deviceType);
}
2777
2778mlir::Value SerialOp::getWaitDevnum() {
2779 return getWaitDevnum(mlir::acc::DeviceType::None);
2780}
2781
2782mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2783 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2784 getWaitOperandsSegments(), getHasWaitDevnum(),
2785 deviceType);
2786}
2787
/// Verifies acc.serial.
///
/// The checks run in this order:
///  - The private, firstprivate and reduction operands go through
///    checkPrivateOperands, instantiated with the matching clause op and
///    recipe op types.
///  - The wait operands are checked against their segment and device_type
///    attributes.
///  - The async operands are checked against their device_type attribute.
///  - The data clause operands go through checkDataOperands.
///
/// The first failing check ends verification.
LogicalResult acc::SerialOp::verify() {
  // private(...) operands.
  if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
                                  mlir::acc::PrivateRecipeOp>(
          *this, getPrivateOperands(), "private")))
    return failure();
  // firstprivate(...) operands.
  if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
                                  mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivateOperands(), "firstprivate")))
    return failure();
  // reduction(...) operands.
  if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
                                  mlir::acc::ReductionRecipeOp>(
          *this, getReductionOperands(), "reduction")))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

    return failure();

  // Data clause operands are checked last.
  return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
}
2817
2818void acc::SerialOp::addAsyncOnly(
2819 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2820 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2821 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2822}
2823
2824void acc::SerialOp::addAsyncOperand(
2825 MLIRContext *context, mlir::Value newValue,
2826 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2827 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2828 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2829 getAsyncOperandsMutable()));
2830}
2831
2832void acc::SerialOp::addWaitOnly(
2833 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2834 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2835 effectiveDeviceTypes));
2836}
/// Adds `newValues` to the wait operands for `effectiveDeviceTypes`.
///
/// This updates three pieces of state:
///  - the wait device_type attribute, via addDeviceTypeAffectedOperandHelper;
///  - the wait operand segments, which start from any existing segments;
///  - the hasWaitDevnum flags, which receive `hasDevnum` once per effective
///    device type.
void acc::SerialOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  // The helper appends the operands and updates `segments`, then returns the
  // new device_type attribute.
  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // One flag per effective device type. When the list is empty, exactly one
  // flag is still added (std::max with 1).
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
2859
2860void acc::SerialOp::addPrivatization(MLIRContext *context,
2861 mlir::acc::PrivateOp op,
2862 mlir::acc::PrivateRecipeOp recipe) {
2863 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2864 getPrivateOperandsMutable().append(op.getResult());
2865}
2866
2867void acc::SerialOp::addFirstPrivatization(
2868 MLIRContext *context, mlir::acc::FirstprivateOp op,
2869 mlir::acc::FirstprivateRecipeOp recipe) {
2870 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2871 getFirstprivateOperandsMutable().append(op.getResult());
2872}
2873
2874void acc::SerialOp::addReduction(MLIRContext *context,
2875 mlir::acc::ReductionOp op,
2876 mlir::acc::ReductionRecipeOp recipe) {
2877 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2878 getReductionOperandsMutable().append(op.getResult());
2879}
2880
2881//===----------------------------------------------------------------------===//
2882// KernelsOp
2883//===----------------------------------------------------------------------===//
2884
2885unsigned KernelsOp::getNumDataOperands() {
2886 return getDataClauseOperands().size();
2887}
2888
2889Value KernelsOp::getDataOperand(unsigned i) {
2890 unsigned numOptional = getAsyncOperands().size();
2891 numOptional += getWaitOperands().size();
2892 numOptional += getNumGangs().size();
2893 numOptional += getNumWorkers().size();
2894 numOptional += getVectorLength().size();
2895 numOptional += getIfCond() ? 1 : 0;
2896 numOptional += getSelfCond() ? 1 : 0;
2897 return getOperand(numOptional + i);
2898}
2899
2900bool acc::KernelsOp::hasAsyncOnly() {
2901 return hasAsyncOnly(mlir::acc::DeviceType::None);
2902}
2903
2904bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2905 return hasDeviceType(getAsyncOnly(), deviceType);
2906}
2907
2908mlir::Value acc::KernelsOp::getAsyncValue() {
2909 return getAsyncValue(mlir::acc::DeviceType::None);
2910}
2911
/// Returns the async operand that corresponds to `deviceType`.
mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
    getAsyncOperands(), deviceType);
}
2916
2917mlir::Value acc::KernelsOp::getNumWorkersValue() {
2918 return getNumWorkersValue(mlir::acc::DeviceType::None);
2919}
2920
acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
  // Select the num_workers operand paired with `deviceType` through the
  // num_workers device_type attribute.
  return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
                                     deviceType);
}
2926
2927mlir::Value acc::KernelsOp::getVectorLengthValue() {
2928 return getVectorLengthValue(mlir::acc::DeviceType::None);
2929}
2930
acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
  // Select the vector_length operand paired with `deviceType` through the
  // vector_length device_type attribute.
  return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
                                     getVectorLength(), deviceType);
}
2936
2937mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2938 return getNumGangsValues(mlir::acc::DeviceType::None);
2939}
2940
KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
  // Select the num_gangs operand group belonging to `deviceType`, using the
  // num_gangs device_type and segment attributes.
  return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
                               getNumGangsSegments(), deviceType);
}
2946
2947bool acc::KernelsOp::hasWaitOnly() {
2948 return hasWaitOnly(mlir::acc::DeviceType::None);
2949}
2950
2951bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2952 return hasDeviceType(getWaitOnly(), deviceType);
2953}
2954
2955mlir::Operation::operand_range KernelsOp::getWaitValues() {
2956 return getWaitValues(mlir::acc::DeviceType::None);
2957}
2958
/// Returns the wait values for `deviceType`. The selection uses the wait
/// operands together with their device_type, segment and hasWaitDevnum
/// attributes.
KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
    getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
    getHasWaitDevnum(), deviceType);
}
2965
2966mlir::Value KernelsOp::getWaitDevnum() {
2967 return getWaitDevnum(mlir::acc::DeviceType::None);
2968}
2969
2970mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2971 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2972 getWaitOperandsSegments(), getHasWaitDevnum(),
2973 deviceType);
2974}
2975
/// Verifies acc.kernels.
///
/// The checks run in this order:
///  - The num_gangs operands are checked against their segment and
///    device_type attributes. The trailing `3` passed to that check is
///    presumably the per-device-type limit on num_gangs values; confirm.
///  - The wait operands are checked against their segment and device_type
///    attributes.
///  - The num_workers, vector_length and async operands are each checked
///    against their device_type attribute.
///  - The data clause operands go through checkDataOperands.
///
/// NOTE(review): unlike SerialOp::verify, no checkPrivateOperands calls are
/// made here, even though KernelsOp accepts private, firstprivate and
/// reduction operands (addPrivatization etc.). Confirm this is intended.
LogicalResult acc::KernelsOp::verify() {
          *this, getNumGangs(), getNumGangsSegmentsAttr(),
          getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
    return failure();

          *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
          getWaitOperandsDeviceTypeAttr(), "wait")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
                                        getNumWorkersDeviceTypeAttr(),
                                        "num_workers")))
    return failure();

  if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
                                        getVectorLengthDeviceTypeAttr(),
                                        "vector_length")))
    return failure();

                                        getAsyncOperandsDeviceTypeAttr(),
                                        "async")))
    return failure();

    return failure();

  // Data clause operands are checked last.
  return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
}
3007
3008void acc::KernelsOp::addPrivatization(MLIRContext *context,
3009 mlir::acc::PrivateOp op,
3010 mlir::acc::PrivateRecipeOp recipe) {
3011 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3012 getPrivateOperandsMutable().append(op.getResult());
3013}
3014
3015void acc::KernelsOp::addFirstPrivatization(
3016 MLIRContext *context, mlir::acc::FirstprivateOp op,
3017 mlir::acc::FirstprivateRecipeOp recipe) {
3018 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3019 getFirstprivateOperandsMutable().append(op.getResult());
3020}
3021
3022void acc::KernelsOp::addReduction(MLIRContext *context,
3023 mlir::acc::ReductionOp op,
3024 mlir::acc::ReductionRecipeOp recipe) {
3025 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3026 getReductionOperandsMutable().append(op.getResult());
3027}
3028
3029void acc::KernelsOp::addNumWorkersOperand(
3030 MLIRContext *context, mlir::Value newValue,
3031 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3032 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3033 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3034 getNumWorkersMutable()));
3035}
3036
3037void acc::KernelsOp::addVectorLengthOperand(
3038 MLIRContext *context, mlir::Value newValue,
3039 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3040 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3041 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3042 getVectorLengthMutable()));
3043}
3044void acc::KernelsOp::addAsyncOnly(
3045 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3046 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3047 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3048}
3049
3050void acc::KernelsOp::addAsyncOperand(
3051 MLIRContext *context, mlir::Value newValue,
3052 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3053 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3054 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3055 getAsyncOperandsMutable()));
3056}
3057
/// Adds `newValues` to the num_gangs operands for `effectiveDeviceTypes`.
/// It updates the num_gangs device_type attribute through
/// addDeviceTypeAffectedOperandHelper, and the num_gangs segments starting
/// from any existing segments.
void acc::KernelsOp::addNumGangsOperands(
    MLIRContext *context, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
  if (getNumGangsSegmentsAttr())
    llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));

  // The helper appends the operands and updates `segments`, then returns the
  // new device_type attribute.
  setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getNumGangsMutable(), segments));

  setNumGangsSegments(segments);
}
3071
3072void acc::KernelsOp::addWaitOnly(
3073 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3074 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3075 effectiveDeviceTypes));
3076}
/// Adds `newValues` to the wait operands for `effectiveDeviceTypes`.
///
/// This updates three pieces of state:
///  - the wait device_type attribute, via addDeviceTypeAffectedOperandHelper;
///  - the wait operand segments, which start from any existing segments;
///  - the hasWaitDevnum flags, which receive `hasDevnum` once per effective
///    device type.
void acc::KernelsOp::addWaitOperands(
    MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
    llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {

  if (getWaitOperandsSegments())
    llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));

  // The helper appends the operands and updates `segments`, then returns the
  // new device_type attribute.
  setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
      context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
      getWaitOperandsMutable(), segments));
  setWaitOperandsSegments(segments);

  if (getHasWaitDevnumAttr())
    llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
  // One flag per effective device type. When the list is empty, exactly one
  // flag is still added (std::max with 1).
  hasDevnums.insert(
      hasDevnums.end(),
      std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
      mlir::BoolAttr::get(context, hasDevnum));
  setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
}
3099
3100//===----------------------------------------------------------------------===//
3101// HostDataOp
3102//===----------------------------------------------------------------------===//
3103
3104LogicalResult acc::HostDataOp::verify() {
3105 if (getDataClauseOperands().empty())
3106 return emitError("at least one operand must appear on the host_data "
3107 "operation");
3108
3110 for (mlir::Value operand : getDataClauseOperands()) {
3111 auto useDeviceOp =
3112 mlir::dyn_cast<acc::UseDeviceOp>(operand.getDefiningOp());
3113 if (!useDeviceOp)
3114 return emitError("expect data entry operation as defining op");
3115
3116 // Check for duplicate use_device clauses
3117 if (!seenVars.insert(useDeviceOp.getVar()).second)
3118 return emitError("duplicate use_device variable");
3119 }
3120 return success();
3121}
3122
3123void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3124 MLIRContext *context) {
3125 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
3126}
3127
3128//===----------------------------------------------------------------------===//
3129// LoopOp
3130//===----------------------------------------------------------------------===//
3131
/// Optionally parses one gang value of the form `keyword = operand : type`.
///
/// When `keyword` is not present, nothing is consumed and success is returned
/// with the output flags untouched. When it is present, the operand and type
/// are appended to the output lists and `gangArgType` is appended to
/// `attributes`. Both `needCommaBetweenValues` and `newValue` are then set to
/// true.
static ParseResult parseGangValue(
    OpAsmParser &parser, llvm::StringRef keyword,
    llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
    bool &needCommaBetweenValues, bool &newValue) {
  if (succeeded(parser.parseOptionalKeyword(keyword))) {
    if (parser.parseEqual())
      return failure();
    if (parser.parseOperand(operands.emplace_back()) ||
        parser.parseColonType(types.emplace_back()))
      return failure();
    // Record the kind of the value just parsed. Any further value in the same
    // group must be preceded by a comma.
    attributes.push_back(gangArgType);
    needCommaBetweenValues = true;
    newValue = true;
  }
  return success();
}
3150
3151static ParseResult parseGangClause(
3152 OpAsmParser &parser,
3154 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
3155 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
3156 mlir::ArrayAttr &gangOnlyDeviceType) {
3157 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
3158 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
3159 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
3161 bool needCommaBetweenValues = false;
3162 bool needCommaBeforeOperands = false;
3163
3164 if (failed(parser.parseOptionalLParen())) {
3165 // Gang only keyword
3166 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3167 parser.getContext(), mlir::acc::DeviceType::None));
3168 gangOnlyDeviceType =
3169 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
3170 return success();
3171 }
3172
3173 // Parse gang only attributes
3174 if (succeeded(parser.parseOptionalLSquare())) {
3175 // Parse gang only attributes
3176 if (failed(parser.parseCommaSeparatedList([&]() {
3177 if (parser.parseAttribute(
3178 gangOnlyDeviceTypeAttributes.emplace_back()))
3179 return failure();
3180 return success();
3181 })))
3182 return failure();
3183 if (parser.parseRSquare())
3184 return failure();
3185 needCommaBeforeOperands = true;
3186 }
3187
3188 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3189 mlir::acc::GangArgType::Num);
3190 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
3191 mlir::acc::GangArgType::Dim);
3192 auto argStatic = mlir::acc::GangArgTypeAttr::get(
3193 parser.getContext(), mlir::acc::GangArgType::Static);
3194
3195 do {
3196 if (needCommaBeforeOperands) {
3197 needCommaBeforeOperands = false;
3198 continue;
3199 }
3200
3201 if (failed(parser.parseLBrace()))
3202 return failure();
3203
3204 int32_t crtOperandsSize = gangOperands.size();
3205 while (true) {
3206 bool newValue = false;
3207 bool needValue = false;
3208 if (needCommaBetweenValues) {
3209 if (succeeded(parser.parseOptionalComma()))
3210 needValue = true; // expect a new value after comma.
3211 else
3212 break;
3213 }
3214
3215 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
3216 gangOperands, gangOperandsType,
3217 gangArgTypeAttributes, argNum,
3218 needCommaBetweenValues, newValue)))
3219 return failure();
3220 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
3221 gangOperands, gangOperandsType,
3222 gangArgTypeAttributes, argDim,
3223 needCommaBetweenValues, newValue)))
3224 return failure();
3225 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
3226 gangOperands, gangOperandsType,
3227 gangArgTypeAttributes, argStatic,
3228 needCommaBetweenValues, newValue)))
3229 return failure();
3230
3231 if (!newValue && needValue) {
3232 parser.emitError(parser.getCurrentLocation(),
3233 "new value expected after comma");
3234 return failure();
3235 }
3236
3237 if (!newValue)
3238 break;
3239 }
3240
3241 if (gangOperands.empty())
3242 return parser.emitError(
3243 parser.getCurrentLocation(),
3244 "expect at least one of num, dim or static values");
3245
3246 if (failed(parser.parseRBrace()))
3247 return failure();
3248
3249 if (succeeded(parser.parseOptionalLSquare())) {
3250 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3251 parser.parseRSquare())
3252 return failure();
3253 } else {
3254 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3255 parser.getContext(), mlir::acc::DeviceType::None));
3256 }
3257
3258 seg.push_back(gangOperands.size() - crtOperandsSize);
3259
3260 } while (succeeded(parser.parseOptionalComma()));
3261
3262 if (failed(parser.parseRParen()))
3263 return failure();
3264
3265 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3266 gangArgTypeAttributes.end());
3267 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3268 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3269
3271 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3272 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3273
3274 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3275 return success();
3276}
3277
    mlir::OperandRange operands, mlir::TypeRange types,
    std::optional<mlir::ArrayAttr> gangArgTypes,
    std::optional<mlir::ArrayAttr> deviceTypes,
    std::optional<mlir::DenseI32ArrayAttr> segments,
    std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
  // Printer for the gang clause operands. Output shape:
  //   ( [gang-only device types] , {kind=%v : type, ...} [dt], ... )
  // Nothing is printed when there are no operands and
  // hasOnlyDeviceTypeNone(gangOnlyDeviceTypes) holds.
  // `types` is not read; each type comes from operands[opIdx].getType().

  if (operands.begin() == operands.end() &&
      hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
    return;
  }

  p << "(";

  printDeviceTypes(p, gangOnlyDeviceTypes);

  if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
      hasDeviceTypeValues(deviceTypes))
    p << ", ";

  if (hasDeviceTypeValues(deviceTypes)) {
    // `opIdx` walks operands and gangArgTypes in lockstep across all groups.
    unsigned opIdx = 0;
    llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
      p << "{";
      // Print (*segments)[group] values for this group. The inner lambda's
      // `it` shadows the outer one and is unused.
      llvm::interleaveComma(
          llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
            auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
                (*gangArgTypes)[opIdx]);
            if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
              p << LoopOp::getGangNumKeyword();
            else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
              p << LoopOp::getGangDimKeyword();
            else if (gangArgTypeAttr.getValue() ==
                     mlir::acc::GangArgType::Static)
              p << LoopOp::getGangStaticKeyword();
            p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
            ++opIdx;
          });
      p << "}";
      printSingleDeviceType(p, it.value());
    });
  }
  p << ")";
}
3322
    std::optional<mlir::ArrayAttr> segments,
    llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
  // Inserts every device type of `segments` into the caller-owned set
  // `deviceTypes`. Returns true as soon as a device type was already present,
  // and false when `segments` is absent or every insertion is new. Because
  // the set is shared across calls (see the auto/independent/seq check in
  // LoopOp::verify), this also catches one device type appearing in two
  // different attributes.
  if (!segments)
    return false;
  for (auto attr : *segments) {
    // NOTE(review): the dyn_cast result is not null-checked, so this relies
    // on every element being a DeviceTypeAttr.
    auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
    if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
      return true;
  }
  return false;
}
3335
3336/// Check for duplicates in the DeviceType array attribute.
3337/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3338static std::optional<mlir::acc::DeviceType>
3339checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3340 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3341 if (!deviceTypes)
3342 return std::nullopt;
3343 for (auto attr : deviceTypes) {
3344 auto deviceTypeAttr =
3345 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3346 if (!deviceTypeAttr)
3347 return mlir::acc::DeviceType::None;
3348 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3349 return deviceTypeAttr.getValue();
3350 }
3351 return std::nullopt;
3352}
3353
/// Verifies acc.loop. The checks, in order:
///  - Bounds: lowerbound, upperbound and step lists have equal length. When
///    bounds and inclusiveUpperbound are both present, their sizes match.
///  - Collapse: collapse values require collapse device types, the counts
///    match, and no collapse device type repeats.
///  - Gang: gang operands require a matching count of argument-type entries,
///    gang device types are unique, and segments agree with device types.
///  - Worker, vector and tile: device types are unique and operand counts
///    agree with their device_type (and segment) attributes.
///  - auto/independent/seq: they never share a device type, and at least one
///    of them is present for DeviceType::None.
///  - seq excludes gang, worker and vector (attribute or operand) for the
///    same device type.
///  - private, firstprivate and reduction operands pass checkPrivateOperands.
///  - The combined attribute, when present, is one of the known constructs.
///  - The body is non-empty. An unstructured loop must be container-like. A
///    structured container-like loop must nest enough loop-like ops for its
///    largest collapse value, with no sibling loops.
LogicalResult acc::LoopOp::verify() {
  if (getUpperbound().size() != getStep().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of steps";

  if (getUpperbound().size() != getLowerbound().size())
    return emitError() << "number of upperbounds expected to be the same as "
                          "number of lowerbounds";

  if (!getUpperbound().empty() && getInclusiveUpperbound() &&
      (getUpperbound().size() != getInclusiveUpperbound()->size()))
    return emitError() << "inclusiveUpperbound size is expected to be the same"
                       << " as upperbound size";

  // Check collapse
  if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
    return emitOpError() << "collapse device_type attr must be define when"
                         << " collapse attr is present";

  if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
      getCollapseAttr().getValue().size() !=
          getCollapseDeviceTypeAttr().getValue().size())
    return emitOpError() << "collapse attribute count must match collapse"
                         << " device_type count";
  if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in collapseDeviceType attribute";

  // Check gang
  if (!getGangOperands().empty()) {
    if (!getGangOperandsArgType())
      return emitOpError() << "gangOperandsArgType attribute must be defined"
                           << " when gang operands are present";

    if (getGangOperands().size() !=
        getGangOperandsArgTypeAttr().getValue().size())
      return emitOpError() << "gangOperandsArgType attribute count must match"
                           << " gangOperands count";
  }
  if (getGangAttr()) {
    if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
      return emitOpError() << "duplicate device_type `"
                           << acc::stringifyDeviceType(*duplicateDeviceType)
                           << "` found in gang attribute";
  }

          *this, getGangOperands(), getGangOperandsSegmentsAttr(),
          getGangOperandsDeviceTypeAttr(), "gang")))
    return failure();

  // Check worker
  if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in worker attribute";
  if (auto duplicateDeviceType =
          checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in workerNumOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
                                        getWorkerNumOperandsDeviceTypeAttr(),
                                        "worker")))
    return failure();

  // Check vector
  if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in vector attribute";
  if (auto duplicateDeviceType =
          checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
    return emitOpError() << "duplicate device_type `"
                         << acc::stringifyDeviceType(*duplicateDeviceType)
                         << "` found in vectorOperandsDeviceType attribute";
  if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
                                        getVectorOperandsDeviceTypeAttr(),
                                        "vector")))
    return failure();

          *this, getTileOperands(), getTileOperandsSegmentsAttr(),
          getTileOperandsDeviceTypeAttr(), "tile")))
    return failure();

  // auto, independent and seq attribute are mutually exclusive.
  // One set is shared by all three calls, so a device type listed in two of
  // these attributes is reported as a duplicate.
  llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
  if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
      hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
      hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
    return emitError() << "only one of auto, independent, seq can be present "
                          "at the same time";
  }

  // Check that at least one of auto, independent, or seq is present
  // for the device-independent default clauses.
  auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
    return attr.getValue() == mlir::acc::DeviceType::None;
  };
  bool hasDefaultSeq =
      getSeqAttr()
          ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  bool hasDefaultIndependent =
      getIndependentAttr()
          ? llvm::any_of(
                getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                hasDeviceNone)
          : false;
  bool hasDefaultAuto =
      getAuto_Attr()
          ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
                         hasDeviceNone)
          : false;
  if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
    return emitError()
           << "at least one of auto, independent, seq must be present";
  }

  // Gang, worker and vector are incompatible with seq.
  // For each device type carrying seq, reject a gang/worker/vector attribute
  // entry or any gang/worker/vector operand for that same device type.
  if (getSeqAttr()) {
    for (auto attr : getSeqAttr()) {
      auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
      if (hasVector(deviceTypeAttr.getValue()) ||
          getVectorValue(deviceTypeAttr.getValue()) ||
          hasWorker(deviceTypeAttr.getValue()) ||
          getWorkerValue(deviceTypeAttr.getValue()) ||
          hasGang(deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Num,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Dim,
                       deviceTypeAttr.getValue()) ||
          getGangValue(mlir::acc::GangArgType::Static,
                       deviceTypeAttr.getValue()))
        return emitError() << "gang, worker or vector cannot appear with seq";
    }
  }

  if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
                                  mlir::acc::PrivateRecipeOp>(
          *this, getPrivateOperands(), "private")))
    return failure();

  if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
                                  mlir::acc::FirstprivateRecipeOp>(
          *this, getFirstprivateOperands(), "firstprivate")))
    return failure();

  if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
                                  mlir::acc::ReductionRecipeOp>(
          *this, getReductionOperands(), "reduction")))
    return failure();

  // Defensive check: only the three combined-loop kinds are accepted.
  if (getCombined().has_value() &&
      (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
       getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
       getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
    return emitError("unexpected combined constructs attribute");
  }

  // Check non-empty body().
  if (getRegion().empty())
    return emitError("expected non-empty body.");

  if (getUnstructured()) {
    if (!isContainerLike())
      return emitError(
          "unstructured acc.loop must not have induction variables");
  } else if (isContainerLike()) {
    // When it is container-like - it is expected to hold a loop-like operation.
    // Obtain the maximum collapse count - we use this to check that there
    // are enough loops contained.
    uint64_t collapseCount = getCollapseValue().value_or(1);
    if (getCollapseAttr()) {
      for (auto collapseEntry : getCollapseAttr()) {
        auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
        if (intAttr.getValue().getZExtValue() > collapseCount)
          collapseCount = intAttr.getValue().getZExtValue();
      }
    }

    // We want to check that we find enough loop-like operations inside.
    // A pre-order walk visits each op before the ops nested inside it, so an
    // outer loop is always seen before the loops it contains. Each loop found
    // must be nested in the previously found one (`expectedParent`), which
    // makes the loops form a single chain.
    mlir::Operation *expectedParent = this->getOperation();
    bool foundSibling = false;
    getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
      if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
        // This effectively checks that we are not looking at a sibling loop.
        if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
            expectedParent) {
          foundSibling = true;
        }

        collapseCount--;
        expectedParent = op;
      }
      // We found enough contained loops.
      if (collapseCount == 0)
    });

    if (foundSibling)
      return emitError("found sibling loops inside container-like acc.loop");
    if (collapseCount != 0)
      return emitError("failed to find enough loop-like operations inside "
                       "container-like acc.loop");
  }

  return success();
}
3570
3571unsigned LoopOp::getNumDataOperands() {
3572 return getReductionOperands().size() + getPrivateOperands().size() +
3573 getFirstprivateOperands().size();
3574}
3575
3576Value LoopOp::getDataOperand(unsigned i) {
3577 unsigned numOptional =
3578 getLowerbound().size() + getUpperbound().size() + getStep().size();
3579 numOptional += getGangOperands().size();
3580 numOptional += getVectorOperands().size();
3581 numOptional += getWorkerNumOperands().size();
3582 numOptional += getTileOperands().size();
3583 numOptional += getCacheOperands().size();
3584 return getOperand(numOptional + i);
3585}
3586
3587bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3588
3589bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3590 return hasDeviceType(getAuto_(), deviceType);
3591}
3592
3593bool LoopOp::hasIndependent() {
3594 return hasIndependent(mlir::acc::DeviceType::None);
3595}
3596
3597bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3598 return hasDeviceType(getIndependent(), deviceType);
3599}
3600
3601bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3602
3603bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3604 return hasDeviceType(getSeq(), deviceType);
3605}
3606
3607mlir::Value LoopOp::getVectorValue() {
3608 return getVectorValue(mlir::acc::DeviceType::None);
3609}
3610
3611mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3612 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3613 getVectorOperands(), deviceType);
3614}
3615
3616bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3617
3618bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3619 return hasDeviceType(getVector(), deviceType);
3620}
3621
3622mlir::Value LoopOp::getWorkerValue() {
3623 return getWorkerValue(mlir::acc::DeviceType::None);
3624}
3625
3626mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3627 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3628 getWorkerNumOperands(), deviceType);
3629}
3630
3631bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3632
3633bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3634 return hasDeviceType(getWorker(), deviceType);
3635}
3636
3637mlir::Operation::operand_range LoopOp::getTileValues() {
3638 return getTileValues(mlir::acc::DeviceType::None);
3639}
3640
LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
  // Select the tile operand group belonging to `deviceType`, using the tile
  // device_type and segment attributes.
  return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
                               getTileOperandsSegments(), deviceType);
}
3646
3647std::optional<int64_t> LoopOp::getCollapseValue() {
3648 return getCollapseValue(mlir::acc::DeviceType::None);
3649}
3650
3651std::optional<int64_t>
3652LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3653 if (!getCollapseAttr())
3654 return std::nullopt;
3655 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3656 auto intAttr =
3657 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3658 return intAttr.getValue().getZExtValue();
3659 }
3660 return std::nullopt;
3661}
3662
3663mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3664 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3665}
3666
3667mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3668 mlir::acc::DeviceType deviceType) {
3669 if (getGangOperands().empty())
3670 return {};
3671 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3672 int32_t nbOperandsBefore = 0;
3673 for (unsigned i = 0; i < *pos; ++i)
3674 nbOperandsBefore += (*getGangOperandsSegments())[i];
3676 getGangOperands()
3677 .drop_front(nbOperandsBefore)
3678 .take_front((*getGangOperandsSegments())[*pos]);
3679
3680 int32_t argTypeIdx = nbOperandsBefore;
3681 for (auto value : values) {
3682 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3683 (*getGangOperandsArgType())[argTypeIdx]);
3684 if (gangArgTypeAttr.getValue() == gangArgType)
3685 return value;
3686 ++argTypeIdx;
3687 }
3688 }
3689 return {};
3690}
3691
3692bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3693
3694bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3695 return hasDeviceType(getGang(), deviceType);
3696}
3697
3698llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3699 return {&getRegion()};
3700}
3701
3702/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3703/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3704/// `(` ssa-id-and-type-list `)`
3705/// region
3706ParseResult
3709 SmallVectorImpl<Type> &lowerboundType,
3711 SmallVectorImpl<Type> &upperboundType,
3713 SmallVectorImpl<Type> &stepType) {
3714
3716 if (succeeded(
3717 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3718 if (parser.parseLParen() ||
3719 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3720 /*allowType=*/true) ||
3721 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3722 parser.parseOperandList(lowerbound, inductionVars.size(),
3724 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3725 parser.parseKeyword("to") || parser.parseLParen() ||
3726 parser.parseOperandList(upperbound, inductionVars.size(),
3728 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3729 parser.parseKeyword("step") || parser.parseLParen() ||
3730 parser.parseOperandList(step, inductionVars.size(),
3732 parser.parseColonTypeList(stepType) || parser.parseRParen())
3733 return failure();
3734 }
3735 return parser.parseRegion(region, inductionVars);
3736}
3737
3739 ValueRange lowerbound, TypeRange lowerboundType,
3740 ValueRange upperbound, TypeRange upperboundType,
3741 ValueRange steps, TypeRange stepType) {
3742 ValueRange regionArgs = region.front().getArguments();
3743 if (!regionArgs.empty()) {
3744 p << acc::LoopOp::getControlKeyword() << "(";
3745 llvm::interleaveComma(regionArgs, p,
3746 [&p](Value v) { p << v << " : " << v.getType(); });
3747 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3748 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3749 << " : " << stepType << ") ";
3750 }
3751 p.printRegion(region, /*printEntryBlockArgs=*/false);
3752}
3753
3754void acc::LoopOp::addSeq(MLIRContext *context,
3755 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3756 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3757 effectiveDeviceTypes));
3758}
3759
3760void acc::LoopOp::addIndependent(
3761 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3762 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3763 context, getIndependentAttr(), effectiveDeviceTypes));
3764}
3765
3766void acc::LoopOp::addAuto(MLIRContext *context,
3767 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3768 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3769 effectiveDeviceTypes));
3770}
3771
3772void acc::LoopOp::setCollapseForDeviceTypes(
3773 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3774 llvm::APInt value) {
3777
3778 assert((getCollapseAttr() == nullptr) ==
3779 (getCollapseDeviceTypeAttr() == nullptr));
3780 assert(value.getBitWidth() == 64);
3781
3782 if (getCollapseAttr()) {
3783 for (const auto &existing :
3784 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3785 newValues.push_back(std::get<0>(existing));
3786 newDeviceTypes.push_back(std::get<1>(existing));
3787 }
3788 }
3789
3790 if (effectiveDeviceTypes.empty()) {
3791 // If the effective device-types list is empty, this is before there are any
3792 // being applied by device_type, so this should be added as a 'none'.
3793 newValues.push_back(
3794 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3795 newDeviceTypes.push_back(
3796 acc::DeviceTypeAttr::get(context, DeviceType::None));
3797 } else {
3798 for (DeviceType dt : effectiveDeviceTypes) {
3799 newValues.push_back(
3800 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3801 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3802 }
3803 }
3804
3805 setCollapseAttr(ArrayAttr::get(context, newValues));
3806 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3807}
3808
3809void acc::LoopOp::setTileForDeviceTypes(
3810 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3811 ValueRange values) {
3813 if (getTileOperandsSegments())
3814 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3815
3816 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3817 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3818 getTileOperandsMutable(), segments));
3819
3820 setTileOperandsSegments(segments);
3821}
3822
3823void acc::LoopOp::addVectorOperand(
3824 MLIRContext *context, mlir::Value newValue,
3825 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3826 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3827 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3828 newValue, getVectorOperandsMutable()));
3829}
3830
3831void acc::LoopOp::addEmptyVector(
3832 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3833 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3834 effectiveDeviceTypes));
3835}
3836
3837void acc::LoopOp::addWorkerNumOperand(
3838 MLIRContext *context, mlir::Value newValue,
3839 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3840 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3841 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3842 newValue, getWorkerNumOperandsMutable()));
3843}
3844
3845void acc::LoopOp::addEmptyWorker(
3846 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3847 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3848 effectiveDeviceTypes));
3849}
3850
3851void acc::LoopOp::addEmptyGang(
3852 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3853 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3854 effectiveDeviceTypes));
3855}
3856
3857bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3858 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3859 return attr.getValue() == dt;
3860 };
3861 auto testFromArr = [=](ArrayAttr arr) -> bool {
3862 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3863 };
3864
3865 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3866 return true;
3867 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3868 return true;
3869 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3870 return true;
3871
3872 return false;
3873}
3874
3875bool acc::LoopOp::hasDefaultGangWorkerVector() {
3876 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3877 hasGang() || getGangValue(GangArgType::Num) ||
3878 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3879}
3880
3881acc::LoopParMode
3882acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3883 if (hasSeq(deviceType))
3884 return LoopParMode::loop_seq;
3885 if (hasAuto(deviceType))
3886 return LoopParMode::loop_auto;
3887 if (hasIndependent(deviceType))
3888 return LoopParMode::loop_independent;
3889 if (hasSeq())
3890 return LoopParMode::loop_seq;
3891 if (hasAuto())
3892 return LoopParMode::loop_auto;
3893 assert(hasIndependent() &&
3894 "loop must have default auto, seq, or independent");
3895 return LoopParMode::loop_independent;
3896}
3897
3898void acc::LoopOp::addGangOperands(
3899 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3902 if (std::optional<ArrayRef<int32_t>> existingSegments =
3903 getGangOperandsSegments())
3904 llvm::copy(*existingSegments, std::back_inserter(segments));
3905
3906 unsigned beforeCount = segments.size();
3907
3908 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3909 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3910 getGangOperandsMutable(), segments));
3911
3912 setGangOperandsSegments(segments);
3913
3914 // This is a bit of extra work to make sure we update the 'types' correctly by
3915 // adding to the types collection the correct number of times. We could
3916 // potentially add something similar to the
3917 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3918 // excessive for a one-off case.
3919 unsigned numAdded = segments.size() - beforeCount;
3920
3921 if (numAdded > 0) {
3923 if (getGangOperandsArgTypeAttr())
3924 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3925
3926 for (auto i : llvm::index_range(0u, numAdded)) {
3927 llvm::transform(argTypes, std::back_inserter(gangTypes),
3928 [=](mlir::acc::GangArgType gangTy) {
3929 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3930 });
3931 (void)i;
3932 }
3933
3934 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3935 }
3936}
3937
3938void acc::LoopOp::addPrivatization(MLIRContext *context,
3939 mlir::acc::PrivateOp op,
3940 mlir::acc::PrivateRecipeOp recipe) {
3941 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3942 getPrivateOperandsMutable().append(op.getResult());
3943}
3944
3945void acc::LoopOp::addFirstPrivatization(
3946 MLIRContext *context, mlir::acc::FirstprivateOp op,
3947 mlir::acc::FirstprivateRecipeOp recipe) {
3948 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3949 getFirstprivateOperandsMutable().append(op.getResult());
3950}
3951
3952void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3953 mlir::acc::ReductionRecipeOp recipe) {
3954 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3955 getReductionOperandsMutable().append(op.getResult());
3956}
3957
3958//===----------------------------------------------------------------------===//
3959// DataOp
3960//===----------------------------------------------------------------------===//
3961
3962LogicalResult acc::DataOp::verify() {
3963 // 2.6.5. Data Construct restriction
3964 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3965 // attach, or default clause must appear on a data construct.
3966 if (getOperands().empty() && !getDefaultAttr())
3967 return emitError("at least one operand or the default attribute "
3968 "must appear on the data operation");
3969
3970 for (mlir::Value operand : getDataClauseOperands())
3971 if (isa<BlockArgument>(operand) ||
3972 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3973 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3974 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3975 operand.getDefiningOp()))
3976 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3977 "as defining op");
3978
3980 return failure();
3981
3982 return success();
3983}
3984
3985unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3986
3987Value DataOp::getDataOperand(unsigned i) {
3988 unsigned numOptional = getIfCond() ? 1 : 0;
3989 numOptional += getAsyncOperands().size() ? 1 : 0;
3990 numOptional += getWaitOperands().size();
3991 return getOperand(numOptional + i);
3992}
3993
3994bool acc::DataOp::hasAsyncOnly() {
3995 return hasAsyncOnly(mlir::acc::DeviceType::None);
3996}
3997
3998bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3999 return hasDeviceType(getAsyncOnly(), deviceType);
4000}
4001
4002mlir::Value DataOp::getAsyncValue() {
4003 return getAsyncValue(mlir::acc::DeviceType::None);
4004}
4005
4006mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4008 getAsyncOperands(), deviceType);
4009}
4010
4011bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
4012
4013bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4014 return hasDeviceType(getWaitOnly(), deviceType);
4015}
4016
4017mlir::Operation::operand_range DataOp::getWaitValues() {
4018 return getWaitValues(mlir::acc::DeviceType::None);
4019}
4020
4022DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4024 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4025 getHasWaitDevnum(), deviceType);
4026}
4027
4028mlir::Value DataOp::getWaitDevnum() {
4029 return getWaitDevnum(mlir::acc::DeviceType::None);
4030}
4031
4032mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4033 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4034 getWaitOperandsSegments(), getHasWaitDevnum(),
4035 deviceType);
4036}
4037
4038void acc::DataOp::addAsyncOnly(
4039 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4040 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4041 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4042}
4043
4044void acc::DataOp::addAsyncOperand(
4045 MLIRContext *context, mlir::Value newValue,
4046 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4047 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4048 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4049 getAsyncOperandsMutable()));
4050}
4051
4052void acc::DataOp::addWaitOnly(MLIRContext *context,
4053 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4054 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4055 effectiveDeviceTypes));
4056}
4057
4058void acc::DataOp::addWaitOperands(
4059 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4060 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4061
4063 if (getWaitOperandsSegments())
4064 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4065
4066 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4067 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4068 getWaitOperandsMutable(), segments));
4069 setWaitOperandsSegments(segments);
4070
4072 if (getHasWaitDevnumAttr())
4073 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4074 hasDevnums.insert(
4075 hasDevnums.end(),
4076 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4077 mlir::BoolAttr::get(context, hasDevnum));
4078 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4079}
4080
4081//===----------------------------------------------------------------------===//
4082// ExitDataOp
4083//===----------------------------------------------------------------------===//
4084
4085LogicalResult acc::ExitDataOp::verify() {
4086 // 2.6.6. Data Exit Directive restriction
4087 // At least one copyout, delete, or detach clause must appear on an exit data
4088 // directive.
4089 if (getDataClauseOperands().empty())
4090 return emitError("at least one operand must be present in dataOperands on "
4091 "the exit data operation");
4092
4093 // The async attribute represent the async clause without value. Therefore the
4094 // attribute and operand cannot appear at the same time.
4095 if (getAsyncOperand() && getAsync())
4096 return emitError("async attribute cannot appear with asyncOperand");
4097
4098 // The wait attribute represent the wait clause without values. Therefore the
4099 // attribute and operands cannot appear at the same time.
4100 if (!getWaitOperands().empty() && getWait())
4101 return emitError("wait attribute cannot appear with waitOperands");
4102
4103 if (getWaitDevnum() && getWaitOperands().empty())
4104 return emitError("wait_devnum cannot appear without waitOperands");
4105
4106 return success();
4107}
4108
4109unsigned ExitDataOp::getNumDataOperands() {
4110 return getDataClauseOperands().size();
4111}
4112
4113Value ExitDataOp::getDataOperand(unsigned i) {
4114 unsigned numOptional = getIfCond() ? 1 : 0;
4115 numOptional += getAsyncOperand() ? 1 : 0;
4116 numOptional += getWaitDevnum() ? 1 : 0;
4117 return getOperand(getWaitOperands().size() + numOptional + i);
4118}
4119
4120void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4121 MLIRContext *context) {
4122 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
4123}
4124
4125void ExitDataOp::addAsyncOnly(MLIRContext *context,
4126 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4127 assert(effectiveDeviceTypes.empty());
4128 assert(!getAsyncAttr());
4129 assert(!getAsyncOperand());
4130
4131 setAsyncAttr(mlir::UnitAttr::get(context));
4132}
4133
4134void ExitDataOp::addAsyncOperand(
4135 MLIRContext *context, mlir::Value newValue,
4136 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4137 assert(effectiveDeviceTypes.empty());
4138 assert(!getAsyncAttr());
4139 assert(!getAsyncOperand());
4140
4141 getAsyncOperandMutable().append(newValue);
4142}
4143
4144void ExitDataOp::addWaitOnly(MLIRContext *context,
4145 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4146 assert(effectiveDeviceTypes.empty());
4147 assert(!getWaitAttr());
4148 assert(getWaitOperands().empty());
4149 assert(!getWaitDevnum());
4150
4151 setWaitAttr(mlir::UnitAttr::get(context));
4152}
4153
4154void ExitDataOp::addWaitOperands(
4155 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4156 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4157 assert(effectiveDeviceTypes.empty());
4158 assert(!getWaitAttr());
4159 assert(getWaitOperands().empty());
4160 assert(!getWaitDevnum());
4161
4162 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4163 // operands list.
4164 if (hasDevnum) {
4165 getWaitDevnumMutable().append(newValues.front());
4166 newValues = newValues.drop_front();
4167 }
4168
4169 getWaitOperandsMutable().append(newValues);
4170}
4171
4172//===----------------------------------------------------------------------===//
4173// EnterDataOp
4174//===----------------------------------------------------------------------===//
4175
4176LogicalResult acc::EnterDataOp::verify() {
4177 // 2.6.6. Data Enter Directive restriction
4178 // At least one copyin, create, or attach clause must appear on an enter data
4179 // directive.
4180 if (getDataClauseOperands().empty())
4181 return emitError("at least one operand must be present in dataOperands on "
4182 "the enter data operation");
4183
4184 // The async attribute represent the async clause without value. Therefore the
4185 // attribute and operand cannot appear at the same time.
4186 if (getAsyncOperand() && getAsync())
4187 return emitError("async attribute cannot appear with asyncOperand");
4188
4189 // The wait attribute represent the wait clause without values. Therefore the
4190 // attribute and operands cannot appear at the same time.
4191 if (!getWaitOperands().empty() && getWait())
4192 return emitError("wait attribute cannot appear with waitOperands");
4193
4194 if (getWaitDevnum() && getWaitOperands().empty())
4195 return emitError("wait_devnum cannot appear without waitOperands");
4196
4197 for (mlir::Value operand : getDataClauseOperands())
4198 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
4199 operand.getDefiningOp()))
4200 return emitError("expect data entry operation as defining op");
4201
4202 return success();
4203}
4204
4205unsigned EnterDataOp::getNumDataOperands() {
4206 return getDataClauseOperands().size();
4207}
4208
4209Value EnterDataOp::getDataOperand(unsigned i) {
4210 unsigned numOptional = getIfCond() ? 1 : 0;
4211 numOptional += getAsyncOperand() ? 1 : 0;
4212 numOptional += getWaitDevnum() ? 1 : 0;
4213 return getOperand(getWaitOperands().size() + numOptional + i);
4214}
4215
4216void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
4217 MLIRContext *context) {
4218 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
4219}
4220
4221void EnterDataOp::addAsyncOnly(
4222 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4223 assert(effectiveDeviceTypes.empty());
4224 assert(!getAsyncAttr());
4225 assert(!getAsyncOperand());
4226
4227 setAsyncAttr(mlir::UnitAttr::get(context));
4228}
4229
4230void EnterDataOp::addAsyncOperand(
4231 MLIRContext *context, mlir::Value newValue,
4232 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4233 assert(effectiveDeviceTypes.empty());
4234 assert(!getAsyncAttr());
4235 assert(!getAsyncOperand());
4236
4237 getAsyncOperandMutable().append(newValue);
4238}
4239
4240void EnterDataOp::addWaitOnly(MLIRContext *context,
4241 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4242 assert(effectiveDeviceTypes.empty());
4243 assert(!getWaitAttr());
4244 assert(getWaitOperands().empty());
4245 assert(!getWaitDevnum());
4246
4247 setWaitAttr(mlir::UnitAttr::get(context));
4248}
4249
4250void EnterDataOp::addWaitOperands(
4251 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4252 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4253 assert(effectiveDeviceTypes.empty());
4254 assert(!getWaitAttr());
4255 assert(getWaitOperands().empty());
4256 assert(!getWaitDevnum());
4257
4258 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4259 // operands list.
4260 if (hasDevnum) {
4261 getWaitDevnumMutable().append(newValues.front());
4262 newValues = newValues.drop_front();
4263 }
4264
4265 getWaitOperandsMutable().append(newValues);
4266}
4267
4268//===----------------------------------------------------------------------===//
4269// AtomicReadOp
4270//===----------------------------------------------------------------------===//
4271
4272LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4273
4274//===----------------------------------------------------------------------===//
4275// AtomicWriteOp
4276//===----------------------------------------------------------------------===//
4277
4278LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4279
4280//===----------------------------------------------------------------------===//
4281// AtomicUpdateOp
4282//===----------------------------------------------------------------------===//
4283
4284LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4285 PatternRewriter &rewriter) {
4286 if (op.isNoOp()) {
4287 rewriter.eraseOp(op);
4288 return success();
4289 }
4290
4291 if (Value writeVal = op.getWriteOpVal()) {
4292 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4293 op.getIfCond());
4294 return success();
4295 }
4296
4297 return failure();
4298}
4299
4300LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4301
4302LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4303
4304//===----------------------------------------------------------------------===//
4305// AtomicCaptureOp
4306//===----------------------------------------------------------------------===//
4307
4308AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4309 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4310 return op;
4311 return dyn_cast<AtomicReadOp>(getSecondOp());
4312}
4313
4314AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4315 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4316 return op;
4317 return dyn_cast<AtomicWriteOp>(getSecondOp());
4318}
4319
4320AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4321 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4322 return op;
4323 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4324}
4325
4326LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4327
4328//===----------------------------------------------------------------------===//
4329// DeclareEnterOp
4330//===----------------------------------------------------------------------===//
4331
4332template <typename Op>
4333static LogicalResult
4335 bool requireAtLeastOneOperand = true) {
4336 if (operands.empty() && requireAtLeastOneOperand)
4337 return emitError(
4338 op->getLoc(),
4339 "at least one operand must appear on the declare operation");
4340
4341 for (mlir::Value operand : operands) {
4342 if (isa<BlockArgument>(operand) ||
4343 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4344 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4345 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4346 operand.getDefiningOp()))
4347 return op.emitError(
4348 "expect valid declare data entry operation or acc.getdeviceptr "
4349 "as defining op");
4350
4351 mlir::Value var{getVar(operand.getDefiningOp())};
4352 assert(var && "declare operands can only be data entry operations which "
4353 "must have var");
4354 (void)var;
4355 std::optional<mlir::acc::DataClause> dataClauseOptional{
4356 getDataClause(operand.getDefiningOp())};
4357 assert(dataClauseOptional.has_value() &&
4358 "declare operands can only be data entry operations which must have "
4359 "dataClause");
4360 (void)dataClauseOptional;
4361 }
4362
4363 return success();
4364}
4365
4366LogicalResult acc::DeclareEnterOp::verify() {
4367 return checkDeclareOperands(*this, this->getDataClauseOperands());
4368}
4369
4370//===----------------------------------------------------------------------===//
4371// DeclareExitOp
4372//===----------------------------------------------------------------------===//
4373
4374LogicalResult acc::DeclareExitOp::verify() {
4375 if (getToken())
4376 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4377 /*requireAtLeastOneOperand=*/false);
4378 return checkDeclareOperands(*this, this->getDataClauseOperands());
4379}
4380
4381//===----------------------------------------------------------------------===//
4382// DeclareOp
4383//===----------------------------------------------------------------------===//
4384
4385LogicalResult acc::DeclareOp::verify() {
4386 return checkDeclareOperands(*this, this->getDataClauseOperands());
4387}
4388
4389//===----------------------------------------------------------------------===//
4390// RoutineOp
4391//===----------------------------------------------------------------------===//
4392
4393static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4394 acc::DeviceType dtype) {
4395 unsigned parallelism = 0;
4396 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4397 parallelism += op.hasWorker(dtype) ? 1 : 0;
4398 parallelism += op.hasVector(dtype) ? 1 : 0;
4399 parallelism += op.hasSeq(dtype) ? 1 : 0;
4400 return parallelism;
4401}
4402
4403LogicalResult acc::RoutineOp::verify() {
4404 unsigned baseParallelism =
4405 getParallelismForDeviceType(*this, acc::DeviceType::None);
4406
4407 if (baseParallelism > 1)
4408 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4409 "be present at the same time";
4410
4411 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4412 ++dtypeInt) {
4413 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4414 if (dtype == acc::DeviceType::None)
4415 continue;
4416 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4417
4418 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4419 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4420 "be present at the same time for device_type `"
4421 << acc::stringifyDeviceType(dtype) << "`";
4422 }
4423
4424 return success();
4425}
4426
4427static ParseResult parseBindName(OpAsmParser &parser,
4428 mlir::ArrayAttr &bindIdName,
4429 mlir::ArrayAttr &bindStrName,
4430 mlir::ArrayAttr &deviceIdTypes,
4431 mlir::ArrayAttr &deviceStrTypes) {
4432 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4433 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4434 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4435 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4436
4437 if (failed(parser.parseCommaSeparatedList([&]() {
4438 mlir::Attribute newAttr;
4439 bool isSymbolRefAttr;
4440 auto parseResult = parser.parseAttribute(newAttr);
4441 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4442 bindIdNameAttrs.push_back(symbolRefAttr);
4443 isSymbolRefAttr = true;
4444 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4445 bindStrNameAttrs.push_back(stringAttr);
4446 isSymbolRefAttr = false;
4447 }
4448 if (parseResult)
4449 return failure();
4450 if (failed(parser.parseOptionalLSquare())) {
4451 if (isSymbolRefAttr) {
4452 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4453 parser.getContext(), mlir::acc::DeviceType::None));
4454 } else {
4455 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4456 parser.getContext(), mlir::acc::DeviceType::None));
4457 }
4458 } else {
4459 if (isSymbolRefAttr) {
4460 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4461 parser.parseRSquare())
4462 return failure();
4463 } else {
4464 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4465 parser.parseRSquare())
4466 return failure();
4467 }
4468 }
4469 return success();
4470 })))
4471 return failure();
4472
4473 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4474 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4475 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4476 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4477
4478 return success();
4479}
4480
4482 std::optional<mlir::ArrayAttr> bindIdName,
4483 std::optional<mlir::ArrayAttr> bindStrName,
4484 std::optional<mlir::ArrayAttr> deviceIdTypes,
4485 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4486 // Create combined vectors for all bind names and device types
4489
4490 // Append bindIdName and deviceIdTypes
4491 if (hasDeviceTypeValues(deviceIdTypes)) {
4492 allBindNames.append(bindIdName->begin(), bindIdName->end());
4493 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4494 }
4495
4496 // Append bindStrName and deviceStrTypes
4497 if (hasDeviceTypeValues(deviceStrTypes)) {
4498 allBindNames.append(bindStrName->begin(), bindStrName->end());
4499 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4500 }
4501
4502 // Print the combined sequence
4503 if (!allBindNames.empty())
4504 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4505 [&](const auto &pair) {
4506 p << std::get<0>(pair);
4507 printSingleDeviceType(p, std::get<1>(pair));
4508 });
4509}
4510
4511static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4512 mlir::ArrayAttr &gang,
4513 mlir::ArrayAttr &gangDim,
4514 mlir::ArrayAttr &gangDimDeviceTypes) {
4515
4516 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4517 gangDimDeviceTypeAttrs;
4518 bool needCommaBeforeOperands = false;
4519
4520 // Gang keyword only
4521 if (failed(parser.parseOptionalLParen())) {
4522 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4523 parser.getContext(), mlir::acc::DeviceType::None));
4524 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4525 return success();
4526 }
4527
4528 // Parse keyword only attributes
4529 if (succeeded(parser.parseOptionalLSquare())) {
4530 if (failed(parser.parseCommaSeparatedList([&]() {
4531 if (parser.parseAttribute(gangAttrs.emplace_back()))
4532 return failure();
4533 return success();
4534 })))
4535 return failure();
4536 if (parser.parseRSquare())
4537 return failure();
4538 needCommaBeforeOperands = true;
4539 }
4540
4541 if (needCommaBeforeOperands && failed(parser.parseComma()))
4542 return failure();
4543
4544 if (failed(parser.parseCommaSeparatedList([&]() {
4545 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4546 parser.parseColon() ||
4547 parser.parseAttribute(gangDimAttrs.emplace_back()))
4548 return failure();
4549 if (succeeded(parser.parseOptionalLSquare())) {
4550 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4551 parser.parseRSquare())
4552 return failure();
4553 } else {
4554 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4555 parser.getContext(), mlir::acc::DeviceType::None));
4556 }
4557 return success();
4558 })))
4559 return failure();
4560
4561 if (failed(parser.parseRParen()))
4562 return failure();
4563
4564 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4565 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4566 gangDimDeviceTypes =
4567 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4568
4569 return success();
4570}
4571
4573 std::optional<mlir::ArrayAttr> gang,
4574 std::optional<mlir::ArrayAttr> gangDim,
4575 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4576
4577 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4578 gang->size() == 1) {
4579 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4580 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4581 return;
4582 }
4583
4584 p << "(";
4585
4586 printDeviceTypes(p, gang);
4587
4588 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4589 p << ", ";
4590
4591 if (hasDeviceTypeValues(gangDimDeviceTypes))
4592 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4593 [&](const auto &pair) {
4594 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4595 p << std::get<0>(pair);
4596 printSingleDeviceType(p, std::get<1>(pair));
4597 });
4598
4599 p << ")";
4600}
4601
4602static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4603 mlir::ArrayAttr &deviceTypes) {
4605 // Keyword only
4606 if (failed(parser.parseOptionalLParen())) {
4607 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4608 parser.getContext(), mlir::acc::DeviceType::None));
4609 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4610 return success();
4611 }
4612
4613 // Parse device type attributes
4614 if (succeeded(parser.parseOptionalLSquare())) {
4615 if (failed(parser.parseCommaSeparatedList([&]() {
4616 if (parser.parseAttribute(attributes.emplace_back()))
4617 return failure();
4618 return success();
4619 })))
4620 return failure();
4621 if (parser.parseRSquare() || parser.parseRParen())
4622 return failure();
4623 }
4624 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4625 return success();
4626}
4627
4628static void
4630 std::optional<mlir::ArrayAttr> deviceTypes) {
4631
4632 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4633 auto deviceTypeAttr =
4634 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4635 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4636 return;
4637 }
4638
4639 if (!hasDeviceTypeValues(deviceTypes))
4640 return;
4641
4642 p << "([";
4643 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4644 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4645 p << dTypeAttr;
4646 });
4647 p << "])";
4648}
4649
4650bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4651
4652bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4653 return hasDeviceType(getWorker(), deviceType);
4654}
4655
4656bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4657
4658bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4659 return hasDeviceType(getVector(), deviceType);
4660}
4661
4662bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4663
4664bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4665 return hasDeviceType(getSeq(), deviceType);
4666}
4667
4668std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4669RoutineOp::getBindNameValue() {
4670 return getBindNameValue(mlir::acc::DeviceType::None);
4671}
4672
4673std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4674RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4675 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4676 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4677 return std::nullopt;
4678 }
4679
4680 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4681 auto attr = (*getBindIdName())[*pos];
4682 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4683 assert(symbolRefAttr && "expected SymbolRef");
4684 return symbolRefAttr;
4685 }
4686
4687 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4688 auto attr = (*getBindStrName())[*pos];
4689 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4690 assert(stringAttr && "expected String");
4691 return stringAttr;
4692 }
4693
4694 return std::nullopt;
4695}
4696
4697bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4698
4699bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4700 return hasDeviceType(getGang(), deviceType);
4701}
4702
4703std::optional<int64_t> RoutineOp::getGangDimValue() {
4704 return getGangDimValue(mlir::acc::DeviceType::None);
4705}
4706
4707std::optional<int64_t>
4708RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4709 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4710 return std::nullopt;
4711 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4712 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4713 return intAttr.getInt();
4714 }
4715 return std::nullopt;
4716}
4717
4718void RoutineOp::addSeq(MLIRContext *context,
4719 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4720 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4721 effectiveDeviceTypes));
4722}
4723
4724void RoutineOp::addVector(MLIRContext *context,
4725 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4726 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4727 effectiveDeviceTypes));
4728}
4729
4730void RoutineOp::addWorker(MLIRContext *context,
4731 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4732 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4733 effectiveDeviceTypes));
4734}
4735
4736void RoutineOp::addGang(MLIRContext *context,
4737 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4738 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4739 effectiveDeviceTypes));
4740}
4741
4742void RoutineOp::addGang(MLIRContext *context,
4743 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4744 uint64_t val) {
4747
4748 if (getGangDimAttr())
4749 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4750 if (getGangDimDeviceTypeAttr())
4751 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4752
4753 assert(dimValues.size() == deviceTypes.size());
4754
4755 if (effectiveDeviceTypes.empty()) {
4756 dimValues.push_back(
4757 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4758 deviceTypes.push_back(
4759 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4760 } else {
4761 for (DeviceType dt : effectiveDeviceTypes) {
4762 dimValues.push_back(
4763 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4764 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4765 }
4766 }
4767 assert(dimValues.size() == deviceTypes.size());
4768
4769 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4770 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4771}
4772
4773void RoutineOp::addBindStrName(MLIRContext *context,
4774 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4775 mlir::StringAttr val) {
4776 unsigned before = getBindStrNameDeviceTypeAttr()
4777 ? getBindStrNameDeviceTypeAttr().size()
4778 : 0;
4779
4780 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4781 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4782 unsigned after = getBindStrNameDeviceTypeAttr().size();
4783
4785 if (getBindStrNameAttr())
4786 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4787 for (unsigned i = 0; i < after - before; ++i)
4788 vals.push_back(val);
4789
4790 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4791}
4792
4793void RoutineOp::addBindIDName(MLIRContext *context,
4794 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4795 mlir::SymbolRefAttr val) {
4796 unsigned before =
4797 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4798
4799 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4800 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4801 unsigned after = getBindIdNameDeviceTypeAttr().size();
4802
4804 if (getBindIdNameAttr())
4805 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4806 for (unsigned i = 0; i < after - before; ++i)
4807 vals.push_back(val);
4808
4809 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4810}
4811
4812//===----------------------------------------------------------------------===//
4813// InitOp
4814//===----------------------------------------------------------------------===//
4815
4816LogicalResult acc::InitOp::verify() {
4817 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4818 return emitOpError("cannot be nested in a compute operation");
4819 return success();
4820}
4821
4822void acc::InitOp::addDeviceType(MLIRContext *context,
4823 mlir::acc::DeviceType deviceType) {
4825 if (getDeviceTypesAttr())
4826 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4827
4828 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4829 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4830}
4831
4832//===----------------------------------------------------------------------===//
4833// ShutdownOp
4834//===----------------------------------------------------------------------===//
4835
4836LogicalResult acc::ShutdownOp::verify() {
4837 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4838 return emitOpError("cannot be nested in a compute operation");
4839 return success();
4840}
4841
4842void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4843 mlir::acc::DeviceType deviceType) {
4845 if (getDeviceTypesAttr())
4846 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4847
4848 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4849 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4850}
4851
4852//===----------------------------------------------------------------------===//
4853// SetOp
4854//===----------------------------------------------------------------------===//
4855
4856LogicalResult acc::SetOp::verify() {
4857 if (getOperation()->getParentOfType<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>())
4858 return emitOpError("cannot be nested in a compute operation");
4859 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4860 return emitOpError("at least one default_async, device_num, or device_type "
4861 "operand must appear");
4862 return success();
4863}
4864
4865//===----------------------------------------------------------------------===//
4866// UpdateOp
4867//===----------------------------------------------------------------------===//
4868
4869LogicalResult acc::UpdateOp::verify() {
4870 // At least one of host or device should have a value.
4871 if (getDataClauseOperands().empty())
4872 return emitError("at least one value must be present in dataOperands");
4873
4875 getAsyncOperandsDeviceTypeAttr(),
4876 "async")))
4877 return failure();
4878
4880 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4881 getWaitOperandsDeviceTypeAttr(), "wait")))
4882 return failure();
4883
4885 return failure();
4886
4887 for (mlir::Value operand : getDataClauseOperands())
4888 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4889 operand.getDefiningOp()))
4890 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4891 "as defining op");
4892
4893 return success();
4894}
4895
4896unsigned UpdateOp::getNumDataOperands() {
4897 return getDataClauseOperands().size();
4898}
4899
4900Value UpdateOp::getDataOperand(unsigned i) {
4901 unsigned numOptional = getAsyncOperands().size();
4902 numOptional += getIfCond() ? 1 : 0;
4903 return getOperand(getWaitOperands().size() + numOptional + i);
4904}
4905
4906void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4907 MLIRContext *context) {
4908 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4909}
4910
4911bool UpdateOp::hasAsyncOnly() {
4912 return hasAsyncOnly(mlir::acc::DeviceType::None);
4913}
4914
4915bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4916 return hasDeviceType(getAsyncOnly(), deviceType);
4917}
4918
4919mlir::Value UpdateOp::getAsyncValue() {
4920 return getAsyncValue(mlir::acc::DeviceType::None);
4921}
4922
4923mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4925 return {};
4926
4927 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4928 return getAsyncOperands()[*pos];
4929
4930 return {};
4931}
4932
4933bool UpdateOp::hasWaitOnly() {
4934 return hasWaitOnly(mlir::acc::DeviceType::None);
4935}
4936
4937bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4938 return hasDeviceType(getWaitOnly(), deviceType);
4939}
4940
4941mlir::Operation::operand_range UpdateOp::getWaitValues() {
4942 return getWaitValues(mlir::acc::DeviceType::None);
4943}
4944
4946UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4948 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4949 getHasWaitDevnum(), deviceType);
4950}
4951
4952mlir::Value UpdateOp::getWaitDevnum() {
4953 return getWaitDevnum(mlir::acc::DeviceType::None);
4954}
4955
4956mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4957 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4958 getWaitOperandsSegments(), getHasWaitDevnum(),
4959 deviceType);
4960}
4961
4962void UpdateOp::addAsyncOnly(MLIRContext *context,
4963 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4964 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4965 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4966}
4967
4968void UpdateOp::addAsyncOperand(
4969 MLIRContext *context, mlir::Value newValue,
4970 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4971 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4972 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4973 getAsyncOperandsMutable()));
4974}
4975
4976void UpdateOp::addWaitOnly(MLIRContext *context,
4977 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4978 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4979 effectiveDeviceTypes));
4980}
4981
4982void UpdateOp::addWaitOperands(
4983 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4984 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4985
4987 if (getWaitOperandsSegments())
4988 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4989
4990 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4991 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4992 getWaitOperandsMutable(), segments));
4993 setWaitOperandsSegments(segments);
4994
4996 if (getHasWaitDevnumAttr())
4997 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4998 hasDevnums.insert(
4999 hasDevnums.end(),
5000 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
5001 mlir::BoolAttr::get(context, hasDevnum));
5002 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
5003}
5004
5005//===----------------------------------------------------------------------===//
5006// WaitOp
5007//===----------------------------------------------------------------------===//
5008
5009LogicalResult acc::WaitOp::verify() {
5010 // The async attribute represent the async clause without value. Therefore the
5011 // attribute and operand cannot appear at the same time.
5012 if (getAsyncOperand() && getAsync())
5013 return emitError("async attribute cannot appear with asyncOperand");
5014
5015 if (getWaitDevnum() && getWaitOperands().empty())
5016 return emitError("wait_devnum cannot appear without waitOperands");
5017
5018 return success();
5019}
5020
5021#define GET_OP_CLASSES
5022#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
5023
5024#define GET_ATTRDEF_CLASSES
5025#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
5026
5027#define GET_TYPEDEF_CLASSES
5028#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
5029
5030//===----------------------------------------------------------------------===//
5031// acc dialect utilities
5032//===----------------------------------------------------------------------===//
5033
5036 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
5038 accDataClauseOp)
5039 .Case<ACC_DATA_ENTRY_OPS>(
5040 [&](auto entry) { return entry.getVarPtr(); })
5041 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5042 [&](auto exit) { return exit.getVarPtr(); })
5043 .Default([&](mlir::Operation *) {
5045 })};
5046 return varPtr;
5047}
5048
5050 auto varPtr{
5052 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
5053 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5054 return varPtr;
5055}
5056
5058 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
5059 .Case<ACC_DATA_ENTRY_OPS>(
5060 [&](auto entry) { return entry.getVarType(); })
5061 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
5062 [&](auto exit) { return exit.getVarType(); })
5063 .Default([&](mlir::Operation *) { return mlir::Type(); })};
5064 return varType;
5065}
5066
5069 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
5071 accDataClauseOp)
5072 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
5073 [&](auto dataClause) { return dataClause.getAccPtr(); })
5074 .Default([&](mlir::Operation *) {
5076 })};
5077 return accPtr;
5078}
5079
5081 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
5083 [&](auto dataClause) { return dataClause.getAccVar(); })
5084 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5085 return accPtr;
5086}
5087
5089 auto varPtrPtr{
5091 .Case<ACC_DATA_ENTRY_OPS>(
5092 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
5093 .Default([&](mlir::Operation *) { return mlir::Value(); })};
5094 return varPtrPtr;
5095}
5096
5101 accDataClauseOp)
5102 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5104 dataClause.getBounds().begin(), dataClause.getBounds().end());
5105 })
5106 .Default([&](mlir::Operation *) {
5108 })};
5109 return bounds;
5110}
5111
5115 accDataClauseOp)
5116 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5118 dataClause.getAsyncOperands().begin(),
5119 dataClause.getAsyncOperands().end());
5120 })
5121 .Default([&](mlir::Operation *) {
5123 });
5124}
5125
5126mlir::ArrayAttr
5129 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
5130 return dataClause.getAsyncOperandsDeviceTypeAttr();
5131 })
5132 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5133}
5134
5135mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
5138 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
5139 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
5140}
5141
5142std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
5143 auto name{
5145 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
5146 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
5147 return {};
5148 })};
5149 return name;
5150}
5151
5152std::optional<mlir::acc::DataClause>
5154 auto dataClause{
5156 accDataEntryOp)
5157 .Case<ACC_DATA_ENTRY_OPS>(
5158 [&](auto entry) { return entry.getDataClause(); })
5159 .Default([&](mlir::Operation *) { return std::nullopt; })};
5160 return dataClause;
5161}
5162
5164 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
5165 .Case<ACC_DATA_ENTRY_OPS>(
5166 [&](auto entry) { return entry.getImplicit(); })
5167 .Default([&](mlir::Operation *) { return false; })};
5168 return implicit;
5169}
5170
5172 auto dataOperands{
5175 [&](auto entry) { return entry.getDataClauseOperands(); })
5176 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
5177 return dataOperands;
5178}
5179
5182 auto dataOperands{
5185 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
5186 .Default([&](mlir::Operation *) { return nullptr; })};
5187 return dataOperands;
5188}
5189
5190mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
5191 auto recipe{
5193 .Case<ACC_DATA_ENTRY_OPS>(
5194 [&](auto entry) { return entry.getRecipeAttr(); })
5195 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
5196 return recipe;
5197}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4572
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1459
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3323
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1942
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4427
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:826
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:599
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2438
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:819
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:752
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:583
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:721
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2449
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2354
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:525
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4629
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:3132
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2689
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3339
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4334
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:659
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2643
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:543
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:641
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:675
static void addOperandEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, MutableOperandRange operand)
Helper to add an effect on an operand, referenced by its mutable range.
Definition OpenACC.cpp:1224
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3707
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1897
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2485
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:2024
static void addResultEffect(SmallVectorImpl< SideEffects::EffectInstance< MemoryEffects::Effect > > &effects, Value result)
Helper to add an effect on a result value.
Definition OpenACC.cpp:1234
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:667
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:730
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:554
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:567
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2224
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:422
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:706
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3738
static ValueRange getSingleRegionSuccessorInputs(Operation *op, RegionSuccessor successor)
Definition OpenACC.cpp:433
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4602
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4511
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2337
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2512
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2628
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2291
static bool isEnclosedIntoComputeOp(mlir::Operation *op)
Definition OpenACC.cpp:1218
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2604
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:796
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:3151
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1677
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2673
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2268
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:685
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1911
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2585
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:529
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3278
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2523
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:764
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:619
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1952
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4393
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2274
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2709
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4481
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
false
Parses a map_entries map type from a string format back into its numeric value.
static void replaceOpWithRegion(RewriterBase &rewriter, Operation *op, Region &region)
Replaces the given op with the contents of the given single-block region, using the operands of the b...
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:165
Operation & front()
Definition Block.h:163
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
BlockArgListType getArguments()
Definition Block.h:97
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
unsigned size() const
Returns the current size of the range.
Definition ValueRange.h:156
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:350
This class helps build Operations.
Definition Builders.h:209
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:434
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:433
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
OperandRange operand_range
Definition Operation.h:400
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:259
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:611
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:407
result_range getResults()
Definition Operation.h:444
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
static RegionSuccessor parent()
Initialize a successor that branches after/out of the parent operation.
bool isParent() const
Return true if the successor is the parent operation.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
This class represents a specific instance of an effect.
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:122
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:363
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
#define ACC_COMPUTE_CONSTRUCT_OPS
Definition OpenACC.h:59
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:70
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:47
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:55
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:5080
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:5049
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5068
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:5153
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:5181
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:5098
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:5171
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:5142
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:5163
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:5190
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:5113
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:168
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:5088
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:205
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5135
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:5057
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:5035
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc:DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:5127
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:305
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.