MLIR 22.0.0git
OpenACC.cpp
Go to the documentation of this file.
1//===- OpenACC.cpp - OpenACC MLIR Operations ------------------------------===//
2//
3// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7// =============================================================================
8
14#include "mlir/IR/Builders.h"
18#include "mlir/IR/IRMapping.h"
19#include "mlir/IR/Matchers.h"
21#include "mlir/IR/SymbolTable.h"
22#include "mlir/Support/LLVM.h"
24#include "llvm/ADT/SmallSet.h"
25#include "llvm/ADT/TypeSwitch.h"
26#include "llvm/Support/LogicalResult.h"
27#include <variant>
28
29using namespace mlir;
30using namespace acc;
31
32#include "mlir/Dialect/OpenACC/OpenACCOpsDialect.cpp.inc"
33#include "mlir/Dialect/OpenACC/OpenACCOpsEnums.cpp.inc"
34#include "mlir/Dialect/OpenACC/OpenACCOpsInterfaces.cpp.inc"
35#include "mlir/Dialect/OpenACC/OpenACCTypeInterfaces.cpp.inc"
36#include "mlir/Dialect/OpenACCMPCommon/Interfaces/OpenACCMPOpsInterfaces.cpp.inc"
37
38namespace {
39
40static bool isScalarLikeType(Type type) {
41 return type.isIntOrIndexOrFloat() || isa<ComplexType>(type);
42}
43
44/// Helper function to attach the `VarName` attribute to an operation
45/// if a variable name is provided.
46static void attachVarNameAttr(Operation *op, OpBuilder &builder,
47 StringRef varName) {
48 if (!varName.empty()) {
49 auto varNameAttr = acc::VarNameAttr::get(builder.getContext(), varName);
50 op->setAttr(acc::getVarNameAttrName(), varNameAttr);
51 }
52}
53
54template <typename T>
55struct MemRefPointerLikeModel
56 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel<T>, T> {
57 Type getElementType(Type pointer) const {
58 return cast<T>(pointer).getElementType();
59 }
60
61 mlir::acc::VariableTypeCategory
62 getPointeeTypeCategory(Type pointer, TypedValue<PointerLikeType> varPtr,
63 Type varType) const {
64 if (auto mappableTy = dyn_cast<MappableType>(varType)) {
65 return mappableTy.getTypeCategory(varPtr);
66 }
67 auto memrefTy = cast<T>(pointer);
68 if (!memrefTy.hasRank()) {
69 // This memref is unranked - aka it could have any rank, including a
70 // rank of 0 which could mean scalar. For now, return uncategorized.
71 return mlir::acc::VariableTypeCategory::uncategorized;
72 }
73
74 if (memrefTy.getRank() == 0) {
75 if (isScalarLikeType(memrefTy.getElementType())) {
76 return mlir::acc::VariableTypeCategory::scalar;
77 }
78 // Zero-rank non-scalar - need further analysis to determine the type
79 // category. For now, return uncategorized.
80 return mlir::acc::VariableTypeCategory::uncategorized;
81 }
82
83 // It has a rank - must be an array.
84 assert(memrefTy.getRank() > 0 && "rank expected to be positive");
85 return mlir::acc::VariableTypeCategory::array;
86 }
87
88 mlir::Value genAllocate(Type pointer, OpBuilder &builder, Location loc,
89 StringRef varName, Type varType, Value originalVar,
90 bool &needsFree) const {
91 auto memrefTy = cast<MemRefType>(pointer);
92
93 // Check if this is a static memref (all dimensions are known) - if yes
94 // then we can generate an alloca operation.
95 if (memrefTy.hasStaticShape()) {
96 needsFree = false; // alloca doesn't need deallocation
97 auto allocaOp = memref::AllocaOp::create(builder, loc, memrefTy);
98 attachVarNameAttr(allocaOp, builder, varName);
99 return allocaOp.getResult();
100 }
101
102 // For dynamic memrefs, extract sizes from the original variable if
103 // provided. Otherwise they cannot be handled.
104 if (originalVar && originalVar.getType() == memrefTy &&
105 memrefTy.hasRank()) {
106 SmallVector<Value> dynamicSizes;
107 for (int64_t i = 0; i < memrefTy.getRank(); ++i) {
108 if (memrefTy.isDynamicDim(i)) {
109 // Extract the size of dimension i from the original variable
110 auto indexValue = arith::ConstantIndexOp::create(builder, loc, i);
111 auto dimSize =
112 memref::DimOp::create(builder, loc, originalVar, indexValue);
113 dynamicSizes.push_back(dimSize);
114 }
115 // Note: We only add dynamic sizes to the dynamicSizes array
116 // Static dimensions are handled automatically by AllocOp
117 }
118 needsFree = true; // alloc needs deallocation
119 auto allocOp =
120 memref::AllocOp::create(builder, loc, memrefTy, dynamicSizes);
121 attachVarNameAttr(allocOp, builder, varName);
122 return allocOp.getResult();
123 }
124
125 // TODO: Unranked not yet supported.
126 return {};
127 }
128
129 bool genFree(Type pointer, OpBuilder &builder, Location loc,
130 TypedValue<PointerLikeType> varToFree, Value allocRes,
131 Type varType) const {
132 if (auto memrefValue = dyn_cast<TypedValue<MemRefType>>(varToFree)) {
133 // Use allocRes if provided to determine the allocation type
134 Value valueToInspect = allocRes ? allocRes : memrefValue;
135
136 // Walk through casts to find the original allocation
137 Value currentValue = valueToInspect;
138 Operation *originalAlloc = nullptr;
139
140 // Follow the chain of operations to find the original allocation
141 // even if a casted result is provided.
142 while (currentValue) {
143 if (auto *definingOp = currentValue.getDefiningOp()) {
144 // Check if this is an allocation operation
145 if (isa<memref::AllocOp, memref::AllocaOp>(definingOp)) {
146 originalAlloc = definingOp;
147 break;
148 }
149
150 // Check if this is a cast operation we can look through
151 if (auto castOp = dyn_cast<memref::CastOp>(definingOp)) {
152 currentValue = castOp.getSource();
153 continue;
154 }
155
156 // Check for other cast-like operations
157 if (auto reinterpretCastOp =
158 dyn_cast<memref::ReinterpretCastOp>(definingOp)) {
159 currentValue = reinterpretCastOp.getSource();
160 continue;
161 }
162
163 // If we can't look through this operation, stop
164 break;
165 }
166 // This is a block argument or similar - can't trace further.
167 break;
168 }
169
170 if (originalAlloc) {
171 if (isa<memref::AllocaOp>(originalAlloc)) {
172 // This is an alloca - no dealloc needed, but return true (success)
173 return true;
174 }
175 if (isa<memref::AllocOp>(originalAlloc)) {
176 // This is an alloc - generate dealloc on varToFree
177 memref::DeallocOp::create(builder, loc, memrefValue);
178 return true;
179 }
180 }
181 }
182
183 return false;
184 }
185
186 bool genCopy(Type pointer, OpBuilder &builder, Location loc,
187 TypedValue<PointerLikeType> destination,
188 TypedValue<PointerLikeType> source, Type varType) const {
189 // Generate a copy operation between two memrefs
190 auto destMemref = dyn_cast_if_present<TypedValue<MemRefType>>(destination);
191 auto srcMemref = dyn_cast_if_present<TypedValue<MemRefType>>(source);
192
193 // As per memref documentation, source and destination must have same
194 // element type and shape in order to be compatible. We do not want to fail
195 // with an IR verification error - thus check that before generating the
196 // copy operation.
197 if (destMemref && srcMemref &&
198 destMemref.getType().getElementType() ==
199 srcMemref.getType().getElementType() &&
200 destMemref.getType().getShape() == srcMemref.getType().getShape()) {
201 memref::CopyOp::create(builder, loc, srcMemref, destMemref);
202 return true;
203 }
204
205 return false;
206 }
207
208 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
210 Type valueType) const {
211 // Load from a memref - only valid for scalar memrefs (rank 0).
212 // This is because the address computation for memrefs is part of the load
213 // (and not computed separately), but the API does not have arguments for
214 // indexing.
215 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(srcPtr);
216 if (!memrefValue)
217 return {};
218
219 auto memrefTy = memrefValue.getType();
220
221 // Only load from scalar memrefs (rank 0)
222 if (memrefTy.getRank() != 0)
223 return {};
224
225 return memref::LoadOp::create(builder, loc, memrefValue);
226 }
227
228 bool genStore(Type pointer, OpBuilder &builder, Location loc,
229 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
230 // Store to a memref - only valid for scalar memrefs (rank 0)
231 // This is because the address computation for memrefs is part of the store
232 // (and not computed separately), but the API does not have arguments for
233 // indexing.
234 auto memrefValue = dyn_cast_if_present<TypedValue<MemRefType>>(destPtr);
235 if (!memrefValue)
236 return false;
237
238 auto memrefTy = memrefValue.getType();
239
240 // Only store to scalar memrefs (rank 0)
241 if (memrefTy.getRank() != 0)
242 return false;
243
244 memref::StoreOp::create(builder, loc, valueToStore, memrefValue);
245 return true;
246 }
247};
248
249struct LLVMPointerPointerLikeModel
250 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
251 LLVM::LLVMPointerType> {
252 Type getElementType(Type pointer) const { return Type(); }
253
254 mlir::Value genLoad(Type pointer, OpBuilder &builder, Location loc,
256 Type valueType) const {
257 // For LLVM pointers, we need the valueType to determine what to load
258 if (!valueType)
259 return {};
260
261 return LLVM::LoadOp::create(builder, loc, valueType, srcPtr);
262 }
263
264 bool genStore(Type pointer, OpBuilder &builder, Location loc,
265 Value valueToStore, TypedValue<PointerLikeType> destPtr) const {
266 LLVM::StoreOp::create(builder, loc, valueToStore, destPtr);
267 return true;
268 }
269};
270
271struct MemrefAddressOfGlobalModel
272 : public AddressOfGlobalOpInterface::ExternalModel<
273 MemrefAddressOfGlobalModel, memref::GetGlobalOp> {
274 SymbolRefAttr getSymbol(Operation *op) const {
275 auto getGlobalOp = cast<memref::GetGlobalOp>(op);
276 return getGlobalOp.getNameAttr();
277 }
278};
279
280struct MemrefGlobalVariableModel
281 : public GlobalVariableOpInterface::ExternalModel<MemrefGlobalVariableModel,
282 memref::GlobalOp> {
283 bool isConstant(Operation *op) const {
284 auto globalOp = cast<memref::GlobalOp>(op);
285 return globalOp.getConstant();
286 }
287
288 Region *getInitRegion(Operation *op) const {
289 // GlobalOp uses attributes for initialization, not regions
290 return nullptr;
291 }
292};
293
294/// Helper function for any of the times we need to modify an ArrayAttr based on
295/// a device type list. Returns a new ArrayAttr with all of the
296/// existingDeviceTypes, plus the effective new ones(or an added none if hte new
297/// list is empty).
298mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
299 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
300 llvm::ArrayRef<acc::DeviceType> newDeviceTypes) {
302 if (existingDeviceTypes)
303 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
304
305 if (newDeviceTypes.empty())
306 deviceTypes.push_back(
307 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
308
309 for (DeviceType dt : newDeviceTypes)
310 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
311
312 return mlir::ArrayAttr::get(context, deviceTypes);
313}
314
315/// Helper function for any of the times we need to add operands that are
316/// affected by a device type list. Returns a new ArrayAttr with all of the
317/// existingDeviceTypes, plus the effective new ones (or an added none, if the
318/// new list is empty). Additionally, adds the arguments to the argCollection
319/// the correct number of times. This will also update a 'segments' array, even
320/// if it won't be used.
321mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
322 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
323 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
324 mlir::MutableOperandRange argCollection,
325 llvm::SmallVector<int32_t> &segments) {
327 if (existingDeviceTypes)
328 llvm::copy(existingDeviceTypes, std::back_inserter(deviceTypes));
329
330 if (newDeviceTypes.empty()) {
331 argCollection.append(arguments);
332 segments.push_back(arguments.size());
333 deviceTypes.push_back(
334 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
335 }
336
337 for (DeviceType dt : newDeviceTypes) {
338 argCollection.append(arguments);
339 segments.push_back(arguments.size());
340 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
341 }
342
343 return mlir::ArrayAttr::get(context, deviceTypes);
344}
345
346/// Overload for when the 'segments' aren't needed.
347mlir::ArrayAttr addDeviceTypeAffectedOperandHelper(
348 MLIRContext *context, mlir::ArrayAttr existingDeviceTypes,
349 llvm::ArrayRef<acc::DeviceType> newDeviceTypes, mlir::ValueRange arguments,
350 mlir::MutableOperandRange argCollection) {
352 return addDeviceTypeAffectedOperandHelper(context, existingDeviceTypes,
353 newDeviceTypes, arguments,
354 argCollection, segments);
355}
356} // namespace
357
358//===----------------------------------------------------------------------===//
359// OpenACC operations
360//===----------------------------------------------------------------------===//
361
362void OpenACCDialect::initialize() {
363 addOperations<
364#define GET_OP_LIST
365#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
366 >();
367 addAttributes<
368#define GET_ATTRDEF_LIST
369#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
370 >();
371 addTypes<
372#define GET_TYPEDEF_LIST
373#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
374 >();
375
376 // By attaching interfaces here, we make the OpenACC dialect dependent on
377 // the other dialects. This is probably better than having dialects like LLVM
378 // and memref be dependent on OpenACC.
379 MemRefType::attachInterface<MemRefPointerLikeModel<MemRefType>>(
380 *getContext());
381 UnrankedMemRefType::attachInterface<
382 MemRefPointerLikeModel<UnrankedMemRefType>>(*getContext());
383 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
384 *getContext());
385
386 // Attach operation interfaces
387 memref::GetGlobalOp::attachInterface<MemrefAddressOfGlobalModel>(
388 *getContext());
389 memref::GlobalOp::attachInterface<MemrefGlobalVariableModel>(*getContext());
390}
391
392//===----------------------------------------------------------------------===//
393// RegionBranchOpInterface for acc.kernels / acc.parallel / acc.serial /
394// acc.kernel_environment / acc.data / acc.host_data / acc.loop
395//===----------------------------------------------------------------------===//
396
397/// Generic helper for single-region OpenACC ops that execute their body once
398/// and then return to the parent operation with their results (if any).
399static void
401 RegionBranchPoint point,
403 if (point.isParent()) {
404 regions.push_back(RegionSuccessor(&region));
405 return;
406 }
407
408 regions.push_back(RegionSuccessor(op, op->getResults()));
409}
410
411void KernelsOp::getSuccessorRegions(RegionBranchPoint point,
413 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
414 regions);
415}
416
417void ParallelOp::getSuccessorRegions(
419 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
420 regions);
421}
422
423void SerialOp::getSuccessorRegions(RegionBranchPoint point,
425 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
426 regions);
427}
428
429void KernelEnvironmentOp::getSuccessorRegions(
431 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
432 regions);
433}
434
435void DataOp::getSuccessorRegions(RegionBranchPoint point,
437 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
438 regions);
439}
440
441void HostDataOp::getSuccessorRegions(
443 getSingleRegionOpSuccessorRegions(getOperation(), getRegion(), point,
444 regions);
445}
446
447void LoopOp::getSuccessorRegions(RegionBranchPoint point,
449 // Unstructured loops: the body may contain arbitrary CFG and early exits.
450 // At the RegionBranch level, only model entry into the body and exit to the
451 // parent; any backedges are represented inside the region CFG.
452 if (getUnstructured()) {
453 if (point.isParent()) {
454 regions.push_back(RegionSuccessor(&getRegion()));
455 return;
456 }
457 regions.push_back(RegionSuccessor(getOperation(), getResults()));
458 return;
459 }
460
461 // Structured loops: model a loop-shaped region graph similar to scf.for.
462 regions.push_back(RegionSuccessor(&getRegion()));
463 regions.push_back(RegionSuccessor(getOperation(), getResults()));
464}
465
466//===----------------------------------------------------------------------===//
467// device_type support helpers
468//===----------------------------------------------------------------------===//
469
470static bool hasDeviceTypeValues(std::optional<mlir::ArrayAttr> arrayAttr) {
471 return arrayAttr && *arrayAttr && arrayAttr->size() > 0;
472}
473
474static bool hasDeviceType(std::optional<mlir::ArrayAttr> arrayAttr,
475 mlir::acc::DeviceType deviceType) {
476 if (!hasDeviceTypeValues(arrayAttr))
477 return false;
478
479 for (auto attr : *arrayAttr) {
480 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
481 if (deviceTypeAttr.getValue() == deviceType)
482 return true;
483 }
484
485 return false;
486}
487
489 std::optional<mlir::ArrayAttr> deviceTypes) {
490 if (!hasDeviceTypeValues(deviceTypes))
491 return;
492
493 p << "[";
494 llvm::interleaveComma(*deviceTypes, p,
495 [&](mlir::Attribute attr) { p << attr; });
496 p << "]";
497}
498
499static std::optional<unsigned> findSegment(ArrayAttr segments,
500 mlir::acc::DeviceType deviceType) {
501 unsigned segmentIdx = 0;
502 for (auto attr : segments) {
503 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
504 if (deviceTypeAttr.getValue() == deviceType)
505 return std::make_optional(segmentIdx);
506 ++segmentIdx;
507 }
508 return std::nullopt;
509}
510
512getValuesFromSegments(std::optional<mlir::ArrayAttr> arrayAttr,
514 std::optional<llvm::ArrayRef<int32_t>> segments,
515 mlir::acc::DeviceType deviceType) {
516 if (!arrayAttr)
517 return range.take_front(0);
518 if (auto pos = findSegment(*arrayAttr, deviceType)) {
519 int32_t nbOperandsBefore = 0;
520 for (unsigned i = 0; i < *pos; ++i)
521 nbOperandsBefore += (*segments)[i];
522 return range.drop_front(nbOperandsBefore).take_front((*segments)[*pos]);
523 }
524 return range.take_front(0);
525}
526
527static mlir::Value
528getWaitDevnumValue(std::optional<mlir::ArrayAttr> deviceTypeAttr,
530 std::optional<llvm::ArrayRef<int32_t>> segments,
531 std::optional<mlir::ArrayAttr> hasWaitDevnum,
532 mlir::acc::DeviceType deviceType) {
533 if (!hasDeviceTypeValues(deviceTypeAttr))
534 return {};
535 if (auto pos = findSegment(*deviceTypeAttr, deviceType))
536 if (hasWaitDevnum->getValue()[*pos])
537 return getValuesFromSegments(deviceTypeAttr, operands, segments,
538 deviceType)
539 .front();
540 return {};
541}
542
544getWaitValuesWithoutDevnum(std::optional<mlir::ArrayAttr> deviceTypeAttr,
546 std::optional<llvm::ArrayRef<int32_t>> segments,
547 std::optional<mlir::ArrayAttr> hasWaitDevnum,
548 mlir::acc::DeviceType deviceType) {
549 auto range =
550 getValuesFromSegments(deviceTypeAttr, operands, segments, deviceType);
551 if (range.empty())
552 return range;
553 if (auto pos = findSegment(*deviceTypeAttr, deviceType)) {
554 if (hasWaitDevnum && *hasWaitDevnum) {
555 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasWaitDevnum)[*pos]);
556 if (boolAttr.getValue())
557 return range.drop_front(1); // first value is devnum
558 }
559 }
560 return range;
561}
562
563template <typename Op>
564static LogicalResult checkWaitAndAsyncConflict(Op op) {
565 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
566 ++dtypeInt) {
567 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
568
569 // The asyncOnly attribute represent the async clause without value.
570 // Therefore the attribute and operand cannot appear at the same time.
571 if (hasDeviceType(op.getAsyncOperandsDeviceType(), dtype) &&
572 op.hasAsyncOnly(dtype))
573 return op.emitError(
574 "asyncOnly attribute cannot appear with asyncOperand");
575
576 // The wait attribute represent the wait clause without values. Therefore
577 // the attribute and operands cannot appear at the same time.
578 if (hasDeviceType(op.getWaitOperandsDeviceType(), dtype) &&
579 op.hasWaitOnly(dtype))
580 return op.emitError("wait attribute cannot appear with waitOperands");
581 }
582 return success();
583}
584
585template <typename Op>
586static LogicalResult checkVarAndVarType(Op op) {
587 if (!op.getVar())
588 return op.emitError("must have var operand");
589
590 // A variable must have a type that is either pointer-like or mappable.
591 if (!mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
592 !mlir::isa<mlir::acc::MappableType>(op.getVar().getType()))
593 return op.emitError("var must be mappable or pointer-like");
594
595 // When it is a pointer-like type, the varType must capture the target type.
596 if (mlir::isa<mlir::acc::PointerLikeType>(op.getVar().getType()) &&
597 op.getVarType() == op.getVar().getType())
598 return op.emitError("varType must capture the element type of var");
599
600 return success();
601}
602
603template <typename Op>
604static LogicalResult checkVarAndAccVar(Op op) {
605 if (op.getVar().getType() != op.getAccVar().getType())
606 return op.emitError("input and output types must match");
607
608 return success();
609}
610
611template <typename Op>
612static LogicalResult checkNoModifier(Op op) {
613 if (op.getModifiers() != acc::DataClauseModifier::none)
614 return op.emitError("no data clause modifiers are allowed");
615 return success();
616}
617
618template <typename Op>
619static LogicalResult
620checkValidModifier(Op op, acc::DataClauseModifier validModifiers) {
621 if (acc::bitEnumContainsAny(op.getModifiers(), ~validModifiers))
622 return op.emitError(
623 "invalid data clause modifiers: " +
624 acc::stringifyDataClauseModifier(op.getModifiers() & ~validModifiers));
625
626 return success();
627}
628
629template <typename OpT, typename RecipeOpT>
630static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName) {
631 // Mappable types do not need a recipe because it is possible to generate one
632 // from its API. Reject reductions though because no API is available for them
633 // at this time.
634 if (mlir::acc::isMappableType(op.getVar().getType()) &&
635 !std::is_same_v<OpT, acc::ReductionOp>)
636 return success();
637
638 mlir::SymbolRefAttr operandRecipe = op.getRecipeAttr();
639 if (!operandRecipe)
640 return op->emitOpError() << "recipe expected for " << operandName;
641
642 auto decl =
644 if (!decl)
645 return op->emitOpError()
646 << "expected symbol reference " << operandRecipe << " to point to a "
647 << operandName << " declaration";
648 return success();
649}
650
651static ParseResult parseVar(mlir::OpAsmParser &parser,
653 // Either `var` or `varPtr` keyword is required.
654 if (failed(parser.parseOptionalKeyword("varPtr"))) {
655 if (failed(parser.parseKeyword("var")))
656 return failure();
657 }
658 if (failed(parser.parseLParen()))
659 return failure();
660 if (failed(parser.parseOperand(var)))
661 return failure();
662
663 return success();
664}
665
667 mlir::Value var) {
668 if (mlir::isa<mlir::acc::PointerLikeType>(var.getType()))
669 p << "varPtr(";
670 else
671 p << "var(";
672 p.printOperand(var);
673}
674
675static ParseResult parseAccVar(mlir::OpAsmParser &parser,
677 mlir::Type &accVarType) {
678 // Either `accVar` or `accPtr` keyword is required.
679 if (failed(parser.parseOptionalKeyword("accPtr"))) {
680 if (failed(parser.parseKeyword("accVar")))
681 return failure();
682 }
683 if (failed(parser.parseLParen()))
684 return failure();
685 if (failed(parser.parseOperand(var)))
686 return failure();
687 if (failed(parser.parseColon()))
688 return failure();
689 if (failed(parser.parseType(accVarType)))
690 return failure();
691 if (failed(parser.parseRParen()))
692 return failure();
693
694 return success();
695}
696
698 mlir::Value accVar, mlir::Type accVarType) {
699 if (mlir::isa<mlir::acc::PointerLikeType>(accVar.getType()))
700 p << "accPtr(";
701 else
702 p << "accVar(";
703 p.printOperand(accVar);
704 p << " : ";
705 p.printType(accVarType);
706 p << ")";
707}
708
709static ParseResult parseVarPtrType(mlir::OpAsmParser &parser,
710 mlir::Type &varPtrType,
711 mlir::TypeAttr &varTypeAttr) {
712 if (failed(parser.parseType(varPtrType)))
713 return failure();
714 if (failed(parser.parseRParen()))
715 return failure();
716
717 if (succeeded(parser.parseOptionalKeyword("varType"))) {
718 if (failed(parser.parseLParen()))
719 return failure();
720 mlir::Type varType;
721 if (failed(parser.parseType(varType)))
722 return failure();
723 varTypeAttr = mlir::TypeAttr::get(varType);
724 if (failed(parser.parseRParen()))
725 return failure();
726 } else {
727 // Set `varType` from the element type of the type of `varPtr`.
728 if (mlir::isa<mlir::acc::PointerLikeType>(varPtrType))
729 varTypeAttr = mlir::TypeAttr::get(
730 mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType());
731 else
732 varTypeAttr = mlir::TypeAttr::get(varPtrType);
733 }
734
735 return success();
736}
737
739 mlir::Type varPtrType, mlir::TypeAttr varTypeAttr) {
740 p.printType(varPtrType);
741 p << ")";
742
743 // Print the `varType` only if it differs from the element type of
744 // `varPtr`'s type.
745 mlir::Type varType = varTypeAttr.getValue();
746 mlir::Type typeToCheckAgainst =
747 mlir::isa<mlir::acc::PointerLikeType>(varPtrType)
748 ? mlir::cast<mlir::acc::PointerLikeType>(varPtrType).getElementType()
749 : varPtrType;
750 if (typeToCheckAgainst != varType) {
751 p << " varType(";
752 p.printType(varType);
753 p << ")";
754 }
755}
756
757static ParseResult parseRecipeSym(mlir::OpAsmParser &parser,
758 mlir::SymbolRefAttr &recipeAttr) {
759 if (failed(parser.parseAttribute(recipeAttr)))
760 return failure();
761 return success();
762}
763
765 mlir::SymbolRefAttr recipeAttr) {
766 p << recipeAttr;
767}
768
769//===----------------------------------------------------------------------===//
770// DataBoundsOp
771//===----------------------------------------------------------------------===//
772LogicalResult acc::DataBoundsOp::verify() {
773 auto extent = getExtent();
774 auto upperbound = getUpperbound();
775 if (!extent && !upperbound)
776 return emitError("expected extent or upperbound.");
777 return success();
778}
779
780//===----------------------------------------------------------------------===//
781// PrivateOp
782//===----------------------------------------------------------------------===//
783LogicalResult acc::PrivateOp::verify() {
784 if (getDataClause() != acc::DataClause::acc_private)
785 return emitError(
786 "data clause associated with private operation must match its intent");
787 if (failed(checkVarAndVarType(*this)))
788 return failure();
789 if (failed(checkNoModifier(*this)))
790 return failure();
791 if (failed(
793 return failure();
794 return success();
795}
796
797//===----------------------------------------------------------------------===//
798// FirstprivateOp
799//===----------------------------------------------------------------------===//
800LogicalResult acc::FirstprivateOp::verify() {
801 if (getDataClause() != acc::DataClause::acc_firstprivate)
802 return emitError("data clause associated with firstprivate operation must "
803 "match its intent");
804 if (failed(checkVarAndVarType(*this)))
805 return failure();
806 if (failed(checkNoModifier(*this)))
807 return failure();
809 *this, "firstprivate")))
810 return failure();
811 return success();
812}
813
814//===----------------------------------------------------------------------===//
815// FirstprivateMapInitialOp
816//===----------------------------------------------------------------------===//
817LogicalResult acc::FirstprivateMapInitialOp::verify() {
818 if (getDataClause() != acc::DataClause::acc_firstprivate)
819 return emitError("data clause associated with firstprivate operation must "
820 "match its intent");
821 if (failed(checkVarAndVarType(*this)))
822 return failure();
823 if (failed(checkNoModifier(*this)))
824 return failure();
825 return success();
826}
827
828//===----------------------------------------------------------------------===//
829// ReductionOp
830//===----------------------------------------------------------------------===//
831LogicalResult acc::ReductionOp::verify() {
832 if (getDataClause() != acc::DataClause::acc_reduction)
833 return emitError("data clause associated with reduction operation must "
834 "match its intent");
835 if (failed(checkVarAndVarType(*this)))
836 return failure();
837 if (failed(checkNoModifier(*this)))
838 return failure();
840 *this, "reduction")))
841 return failure();
842 return success();
843}
844
845//===----------------------------------------------------------------------===//
846// DevicePtrOp
847//===----------------------------------------------------------------------===//
848LogicalResult acc::DevicePtrOp::verify() {
849 if (getDataClause() != acc::DataClause::acc_deviceptr)
850 return emitError("data clause associated with deviceptr operation must "
851 "match its intent");
852 if (failed(checkVarAndVarType(*this)))
853 return failure();
854 if (failed(checkVarAndAccVar(*this)))
855 return failure();
856 if (failed(checkNoModifier(*this)))
857 return failure();
858 return success();
859}
860
861//===----------------------------------------------------------------------===//
862// PresentOp
863//===----------------------------------------------------------------------===//
864LogicalResult acc::PresentOp::verify() {
865 if (getDataClause() != acc::DataClause::acc_present)
866 return emitError(
867 "data clause associated with present operation must match its intent");
868 if (failed(checkVarAndVarType(*this)))
869 return failure();
870 if (failed(checkVarAndAccVar(*this)))
871 return failure();
872 if (failed(checkNoModifier(*this)))
873 return failure();
874 return success();
875}
876
877//===----------------------------------------------------------------------===//
878// CopyinOp
879//===----------------------------------------------------------------------===//
880LogicalResult acc::CopyinOp::verify() {
881 // Test for all clauses this operation can be decomposed from:
882 if (!getImplicit() && getDataClause() != acc::DataClause::acc_copyin &&
883 getDataClause() != acc::DataClause::acc_copyin_readonly &&
884 getDataClause() != acc::DataClause::acc_copy &&
885 getDataClause() != acc::DataClause::acc_reduction)
886 return emitError(
887 "data clause associated with copyin operation must match its intent"
888 " or specify original clause this operation was decomposed from");
889 if (failed(checkVarAndVarType(*this)))
890 return failure();
891 if (failed(checkVarAndAccVar(*this)))
892 return failure();
893 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly |
894 acc::DataClauseModifier::always |
895 acc::DataClauseModifier::capture)))
896 return failure();
897 return success();
898}
899
900bool acc::CopyinOp::isCopyinReadonly() {
901 return getDataClause() == acc::DataClause::acc_copyin_readonly ||
902 acc::bitEnumContainsAny(getModifiers(),
903 acc::DataClauseModifier::readonly);
904}
905
906//===----------------------------------------------------------------------===//
907// CreateOp
908//===----------------------------------------------------------------------===//
909LogicalResult acc::CreateOp::verify() {
910 // Test for all clauses this operation can be decomposed from:
911 if (getDataClause() != acc::DataClause::acc_create &&
912 getDataClause() != acc::DataClause::acc_create_zero &&
913 getDataClause() != acc::DataClause::acc_copyout &&
914 getDataClause() != acc::DataClause::acc_copyout_zero)
915 return emitError(
916 "data clause associated with create operation must match its intent"
917 " or specify original clause this operation was decomposed from");
918 if (failed(checkVarAndVarType(*this)))
919 return failure();
920 if (failed(checkVarAndAccVar(*this)))
921 return failure();
922 // this op is the entry part of copyout, so it also needs to allow all
923 // modifiers allowed on copyout.
924 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
925 acc::DataClauseModifier::always |
926 acc::DataClauseModifier::capture)))
927 return failure();
928 return success();
929}
930
931bool acc::CreateOp::isCreateZero() {
932 // The zero modifier is encoded in the data clause.
933 return getDataClause() == acc::DataClause::acc_create_zero ||
934 getDataClause() == acc::DataClause::acc_copyout_zero ||
935 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
936}
937
938//===----------------------------------------------------------------------===//
939// NoCreateOp
940//===----------------------------------------------------------------------===//
941LogicalResult acc::NoCreateOp::verify() {
942 if (getDataClause() != acc::DataClause::acc_no_create)
943 return emitError("data clause associated with no_create operation must "
944 "match its intent");
945 if (failed(checkVarAndVarType(*this)))
946 return failure();
947 if (failed(checkVarAndAccVar(*this)))
948 return failure();
949 if (failed(checkNoModifier(*this)))
950 return failure();
951 return success();
952}
953
954//===----------------------------------------------------------------------===//
955// AttachOp
956//===----------------------------------------------------------------------===//
957LogicalResult acc::AttachOp::verify() {
958 if (getDataClause() != acc::DataClause::acc_attach)
959 return emitError(
960 "data clause associated with attach operation must match its intent");
961 if (failed(checkVarAndVarType(*this)))
962 return failure();
963 if (failed(checkVarAndAccVar(*this)))
964 return failure();
965 if (failed(checkNoModifier(*this)))
966 return failure();
967 return success();
968}
969
970//===----------------------------------------------------------------------===//
971// DeclareDeviceResidentOp
972//===----------------------------------------------------------------------===//
973
974LogicalResult acc::DeclareDeviceResidentOp::verify() {
975 if (getDataClause() != acc::DataClause::acc_declare_device_resident)
976 return emitError("data clause associated with device_resident operation "
977 "must match its intent");
978 if (failed(checkVarAndVarType(*this)))
979 return failure();
980 if (failed(checkVarAndAccVar(*this)))
981 return failure();
982 if (failed(checkNoModifier(*this)))
983 return failure();
984 return success();
985}
986
987//===----------------------------------------------------------------------===//
988// DeclareLinkOp
989//===----------------------------------------------------------------------===//
990
991LogicalResult acc::DeclareLinkOp::verify() {
992 if (getDataClause() != acc::DataClause::acc_declare_link)
993 return emitError(
994 "data clause associated with link operation must match its intent");
995 if (failed(checkVarAndVarType(*this)))
996 return failure();
997 if (failed(checkVarAndAccVar(*this)))
998 return failure();
999 if (failed(checkNoModifier(*this)))
1000 return failure();
1001 return success();
1002}
1003
1004//===----------------------------------------------------------------------===//
1005// CopyoutOp
1006//===----------------------------------------------------------------------===//
1007LogicalResult acc::CopyoutOp::verify() {
1008 // Test for all clauses this operation can be decomposed from:
1009 if (getDataClause() != acc::DataClause::acc_copyout &&
1010 getDataClause() != acc::DataClause::acc_copyout_zero &&
1011 getDataClause() != acc::DataClause::acc_copy &&
1012 getDataClause() != acc::DataClause::acc_reduction)
1013 return emitError(
1014 "data clause associated with copyout operation must match its intent"
1015 " or specify original clause this operation was decomposed from");
1016 if (!getVar() || !getAccVar())
1017 return emitError("must have both host and device pointers");
1018 if (failed(checkVarAndVarType(*this)))
1019 return failure();
1020 if (failed(checkVarAndAccVar(*this)))
1021 return failure();
1022 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1023 acc::DataClauseModifier::always |
1024 acc::DataClauseModifier::capture)))
1025 return failure();
1026 return success();
1027}
1028
1029bool acc::CopyoutOp::isCopyoutZero() {
1030 return getDataClause() == acc::DataClause::acc_copyout_zero ||
1031 acc::bitEnumContainsAny(getModifiers(), acc::DataClauseModifier::zero);
1032}
1033
1034//===----------------------------------------------------------------------===//
1035// DeleteOp
1036//===----------------------------------------------------------------------===//
1037LogicalResult acc::DeleteOp::verify() {
1038 // Test for all clauses this operation can be decomposed from:
1039 if (getDataClause() != acc::DataClause::acc_delete &&
1040 getDataClause() != acc::DataClause::acc_create &&
1041 getDataClause() != acc::DataClause::acc_create_zero &&
1042 getDataClause() != acc::DataClause::acc_copyin &&
1043 getDataClause() != acc::DataClause::acc_copyin_readonly &&
1044 getDataClause() != acc::DataClause::acc_present &&
1045 getDataClause() != acc::DataClause::acc_no_create &&
1046 getDataClause() != acc::DataClause::acc_declare_device_resident &&
1047 getDataClause() != acc::DataClause::acc_declare_link)
1048 return emitError(
1049 "data clause associated with delete operation must match its intent"
1050 " or specify original clause this operation was decomposed from");
1051 if (!getAccVar())
1052 return emitError("must have device pointer");
1053 // This op is the exit part of copyin and create - thus allow all modifiers
1054 // allowed on either case.
1055 if (failed(checkValidModifier(*this, acc::DataClauseModifier::zero |
1056 acc::DataClauseModifier::readonly |
1057 acc::DataClauseModifier::always |
1058 acc::DataClauseModifier::capture)))
1059 return failure();
1060 return success();
1061}
1062
1063//===----------------------------------------------------------------------===//
1064// DetachOp
1065//===----------------------------------------------------------------------===//
1066LogicalResult acc::DetachOp::verify() {
1067 // Test for all clauses this operation can be decomposed from:
1068 if (getDataClause() != acc::DataClause::acc_detach &&
1069 getDataClause() != acc::DataClause::acc_attach)
1070 return emitError(
1071 "data clause associated with detach operation must match its intent"
1072 " or specify original clause this operation was decomposed from");
1073 if (!getAccVar())
1074 return emitError("must have device pointer");
1075 if (failed(checkNoModifier(*this)))
1076 return failure();
1077 return success();
1078}
1079
1080//===----------------------------------------------------------------------===//
// UpdateHostOp
1082//===----------------------------------------------------------------------===//
1083LogicalResult acc::UpdateHostOp::verify() {
1084 // Test for all clauses this operation can be decomposed from:
1085 if (getDataClause() != acc::DataClause::acc_update_host &&
1086 getDataClause() != acc::DataClause::acc_update_self)
1087 return emitError(
1088 "data clause associated with host operation must match its intent"
1089 " or specify original clause this operation was decomposed from");
1090 if (!getVar() || !getAccVar())
1091 return emitError("must have both host and device pointers");
1092 if (failed(checkVarAndVarType(*this)))
1093 return failure();
1094 if (failed(checkVarAndAccVar(*this)))
1095 return failure();
1096 if (failed(checkNoModifier(*this)))
1097 return failure();
1098 return success();
1099}
1100
1101//===----------------------------------------------------------------------===//
// UpdateDeviceOp
1103//===----------------------------------------------------------------------===//
1104LogicalResult acc::UpdateDeviceOp::verify() {
1105 // Test for all clauses this operation can be decomposed from:
1106 if (getDataClause() != acc::DataClause::acc_update_device)
1107 return emitError(
1108 "data clause associated with device operation must match its intent"
1109 " or specify original clause this operation was decomposed from");
1110 if (failed(checkVarAndVarType(*this)))
1111 return failure();
1112 if (failed(checkVarAndAccVar(*this)))
1113 return failure();
1114 if (failed(checkNoModifier(*this)))
1115 return failure();
1116 return success();
1117}
1118
1119//===----------------------------------------------------------------------===//
1120// UseDeviceOp
1121//===----------------------------------------------------------------------===//
1122LogicalResult acc::UseDeviceOp::verify() {
1123 // Test for all clauses this operation can be decomposed from:
1124 if (getDataClause() != acc::DataClause::acc_use_device)
1125 return emitError(
1126 "data clause associated with use_device operation must match its intent"
1127 " or specify original clause this operation was decomposed from");
1128 if (failed(checkVarAndVarType(*this)))
1129 return failure();
1130 if (failed(checkVarAndAccVar(*this)))
1131 return failure();
1132 if (failed(checkNoModifier(*this)))
1133 return failure();
1134 return success();
1135}
1136
1137//===----------------------------------------------------------------------===//
1138// CacheOp
1139//===----------------------------------------------------------------------===//
1140LogicalResult acc::CacheOp::verify() {
1141 // Test for all clauses this operation can be decomposed from:
1142 if (getDataClause() != acc::DataClause::acc_cache &&
1143 getDataClause() != acc::DataClause::acc_cache_readonly)
1144 return emitError(
1145 "data clause associated with cache operation must match its intent"
1146 " or specify original clause this operation was decomposed from");
1147 if (failed(checkVarAndVarType(*this)))
1148 return failure();
1149 if (failed(checkVarAndAccVar(*this)))
1150 return failure();
1151 if (failed(checkValidModifier(*this, acc::DataClauseModifier::readonly)))
1152 return failure();
1153 return success();
1154}
1155
1156bool acc::CacheOp::isCacheReadonly() {
1157 return getDataClause() == acc::DataClause::acc_cache_readonly ||
1158 acc::bitEnumContainsAny(getModifiers(),
1159 acc::DataClauseModifier::readonly);
1160}
1161
1162template <typename StructureOp>
1163static ParseResult parseRegions(OpAsmParser &parser, OperationState &state,
1164 unsigned nRegions = 1) {
1165
1167 for (unsigned i = 0; i < nRegions; ++i)
1168 regions.push_back(state.addRegion());
1169
1170 for (Region *region : regions)
1171 if (parser.parseRegion(*region, /*arguments=*/{}, /*argTypes=*/{}))
1172 return failure();
1173
1174 return success();
1175}
1176
1178 return isa<ACC_COMPUTE_CONSTRUCT_AND_LOOP_OPS>(op);
1179}
1180
1181namespace {
1182/// Pattern to remove operation without region that have constant false `ifCond`
1183/// and remove the condition from the operation if the `ifCond` is a true
1184/// constant.
1185template <typename OpTy>
1186struct RemoveConstantIfCondition : public OpRewritePattern<OpTy> {
1187 using OpRewritePattern<OpTy>::OpRewritePattern;
1188
1189 LogicalResult matchAndRewrite(OpTy op,
1190 PatternRewriter &rewriter) const override {
1191 // Early return if there is no condition.
1192 Value ifCond = op.getIfCond();
1193 if (!ifCond)
1194 return failure();
1195
1196 IntegerAttr constAttr;
1197 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1198 return failure();
1199 if (constAttr.getInt())
1200 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1201 else
1202 rewriter.eraseOp(op);
1203
1204 return success();
1205 }
1206};
1207
1208/// Replaces the given op with the contents of the given single-block region,
1209/// using the operands of the block terminator to replace operation results.
1210static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
1211 Region &region, ValueRange blockArgs = {}) {
1212 assert(region.hasOneBlock() && "expected single-block region");
1213 Block *block = &region.front();
1214 Operation *terminator = block->getTerminator();
1215 ValueRange results = terminator->getOperands();
1216 rewriter.inlineBlockBefore(block, op, blockArgs);
1217 rewriter.replaceOp(op, results);
1218 rewriter.eraseOp(terminator);
1219}
1220
1221/// Pattern to remove operation with region that have constant false `ifCond`
1222/// and remove the condition from the operation if the `ifCond` is constant
1223/// true.
1224template <typename OpTy>
1225struct RemoveConstantIfConditionWithRegion : public OpRewritePattern<OpTy> {
1226 using OpRewritePattern<OpTy>::OpRewritePattern;
1227
1228 LogicalResult matchAndRewrite(OpTy op,
1229 PatternRewriter &rewriter) const override {
1230 // Early return if there is no condition.
1231 Value ifCond = op.getIfCond();
1232 if (!ifCond)
1233 return failure();
1234
1235 IntegerAttr constAttr;
1236 if (!matchPattern(ifCond, m_Constant(&constAttr)))
1237 return failure();
1238 if (constAttr.getInt())
1239 rewriter.modifyOpInPlace(op, [&]() { op.getIfCondMutable().erase(0); });
1240 else
1241 replaceOpWithRegion(rewriter, op, op.getRegion());
1242
1243 return success();
1244 }
1245};
1246
1247/// Remove empty acc.kernel_environment operations. If the operation has wait
1248/// operands, create a acc.wait operation to preserve synchronization.
1249struct RemoveEmptyKernelEnvironment
1250 : public OpRewritePattern<acc::KernelEnvironmentOp> {
1251 using OpRewritePattern<acc::KernelEnvironmentOp>::OpRewritePattern;
1252
1253 LogicalResult matchAndRewrite(acc::KernelEnvironmentOp op,
1254 PatternRewriter &rewriter) const override {
1255 assert(op->getNumRegions() == 1 && "expected op to have one region");
1256
1257 Block &block = op.getRegion().front();
1258 if (!block.empty())
1259 return failure();
1260
1261 // Conservatively disable canonicalization of empty acc.kernel_environment
1262 // operations if the wait operands in the kernel_environment cannot be fully
1263 // represented by acc.wait operation.
1264
1265 // Disable canonicalization if device type is not the default
1266 if (auto deviceTypeAttr = op.getWaitOperandsDeviceTypeAttr()) {
1267 for (auto attr : deviceTypeAttr) {
1268 if (auto dtAttr = mlir::dyn_cast<acc::DeviceTypeAttr>(attr)) {
1269 if (dtAttr.getValue() != mlir::acc::DeviceType::None)
1270 return failure();
1271 }
1272 }
1273 }
1274
1275 // Disable canonicalization if any wait segment has a devnum
1276 if (auto hasDevnumAttr = op.getHasWaitDevnumAttr()) {
1277 for (auto attr : hasDevnumAttr) {
1278 if (auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>(attr)) {
1279 if (boolAttr.getValue())
1280 return failure();
1281 }
1282 }
1283 }
1284
1285 // Disable canonicalization if there are multiple wait segments
1286 if (auto segmentsAttr = op.getWaitOperandsSegmentsAttr()) {
1287 if (segmentsAttr.size() > 1)
1288 return failure();
1289 }
1290
1291 // Remove empty kernel environment.
1292 // Preserve synchronization by creating acc.wait operation if needed.
1293 if (!op.getWaitOperands().empty() || op.getWaitOnlyAttr())
1294 rewriter.replaceOpWithNewOp<acc::WaitOp>(op, op.getWaitOperands(),
1295 /*asyncOperand=*/Value(),
1296 /*waitDevnum=*/Value(),
1297 /*async=*/nullptr,
1298 /*ifCond=*/Value());
1299 else
1300 rewriter.eraseOp(op);
1301
1302 return success();
1303 }
1304};
1305
1306//===----------------------------------------------------------------------===//
1307// Recipe Region Helpers
1308//===----------------------------------------------------------------------===//
1309
1310/// Create and populate an init region for privatization recipes.
1311/// Returns success if the region is populated, failure otherwise.
1312/// Sets needsFree to indicate if the allocated memory requires deallocation.
1313static LogicalResult createInitRegion(OpBuilder &builder, Location loc,
1314 Region &initRegion, Type varType,
1315 StringRef varName, ValueRange bounds,
1316 bool &needsFree) {
1317 // Create init block with arguments: original value + bounds
1318 SmallVector<Type> argTypes{varType};
1319 SmallVector<Location> argLocs{loc};
1320 for (Value bound : bounds) {
1321 argTypes.push_back(bound.getType());
1322 argLocs.push_back(loc);
1323 }
1324
1325 Block *initBlock = builder.createBlock(&initRegion);
1326 initBlock->addArguments(argTypes, argLocs);
1327 builder.setInsertionPointToStart(initBlock);
1328
1329 Value privatizedValue;
1330
1331 // Get the block argument that represents the original variable
1332 Value blockArgVar = initBlock->getArgument(0);
1333
1334 // Generate init region body based on variable type
1335 if (isa<MappableType>(varType)) {
1336 auto mappableTy = cast<MappableType>(varType);
1337 auto typedVar = cast<TypedValue<MappableType>>(blockArgVar);
1338 privatizedValue = mappableTy.generatePrivateInit(
1339 builder, loc, typedVar, varName, bounds, {}, needsFree);
1340 if (!privatizedValue)
1341 return failure();
1342 } else {
1343 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1344 auto pointerLikeTy = cast<PointerLikeType>(varType);
1345 // Use PointerLikeType's allocation API with the block argument
1346 privatizedValue = pointerLikeTy.genAllocate(builder, loc, varName, varType,
1347 blockArgVar, needsFree);
1348 if (!privatizedValue)
1349 return failure();
1350 }
1351
1352 // Add yield operation to init block
1353 acc::YieldOp::create(builder, loc, privatizedValue);
1354
1355 return success();
1356}
1357
1358/// Create and populate a copy region for firstprivate recipes.
1359/// Returns success if the region is populated, failure otherwise.
1360/// TODO: Handle MappableType - it does not yet have a copy API.
1361static LogicalResult createCopyRegion(OpBuilder &builder, Location loc,
1362 Region &copyRegion, Type varType,
1363 ValueRange bounds) {
1364 // Create copy block with arguments: original value + privatized value +
1365 // bounds
1366 SmallVector<Type> copyArgTypes{varType, varType};
1367 SmallVector<Location> copyArgLocs{loc, loc};
1368 for (Value bound : bounds) {
1369 copyArgTypes.push_back(bound.getType());
1370 copyArgLocs.push_back(loc);
1371 }
1372
1373 Block *copyBlock = builder.createBlock(&copyRegion);
1374 copyBlock->addArguments(copyArgTypes, copyArgLocs);
1375 builder.setInsertionPointToStart(copyBlock);
1376
1377 bool isMappable = isa<MappableType>(varType);
1378 bool isPointerLike = isa<PointerLikeType>(varType);
1379 // TODO: Handle MappableType - it does not yet have a copy API.
1380 // Otherwise, for now just fallback to pointer-like behavior.
1381 if (isMappable && !isPointerLike)
1382 return failure();
1383
1384 // Generate copy region body based on variable type
1385 if (isPointerLike) {
1386 auto pointerLikeTy = cast<PointerLikeType>(varType);
1387 Value originalArg = copyBlock->getArgument(0);
1388 Value privatizedArg = copyBlock->getArgument(1);
1389
1390 // Generate copy operation using PointerLikeType interface
1391 if (!pointerLikeTy.genCopy(
1392 builder, loc, cast<TypedValue<PointerLikeType>>(privatizedArg),
1393 cast<TypedValue<PointerLikeType>>(originalArg), varType))
1394 return failure();
1395 }
1396
1397 // Add terminator to copy block
1398 acc::TerminatorOp::create(builder, loc);
1399
1400 return success();
1401}
1402
1403/// Create and populate a destroy region for privatization recipes.
1404/// Returns success if the region is populated, failure otherwise.
1405static LogicalResult createDestroyRegion(OpBuilder &builder, Location loc,
1406 Region &destroyRegion, Type varType,
1407 Value allocRes, ValueRange bounds) {
1408 // Create destroy block with arguments: original value + privatized value +
1409 // bounds
1410 SmallVector<Type> destroyArgTypes{varType, varType};
1411 SmallVector<Location> destroyArgLocs{loc, loc};
1412 for (Value bound : bounds) {
1413 destroyArgTypes.push_back(bound.getType());
1414 destroyArgLocs.push_back(loc);
1415 }
1416
1417 Block *destroyBlock = builder.createBlock(&destroyRegion);
1418 destroyBlock->addArguments(destroyArgTypes, destroyArgLocs);
1419 builder.setInsertionPointToStart(destroyBlock);
1420
1421 auto varToFree =
1422 cast<TypedValue<PointerLikeType>>(destroyBlock->getArgument(1));
1423 if (isa<MappableType>(varType)) {
1424 auto mappableTy = cast<MappableType>(varType);
1425 if (!mappableTy.generatePrivateDestroy(builder, loc, varToFree))
1426 return failure();
1427 } else {
1428 assert(isa<PointerLikeType>(varType) && "Expected PointerLikeType");
1429 auto pointerLikeTy = cast<PointerLikeType>(varType);
1430 if (!pointerLikeTy.genFree(builder, loc, varToFree, allocRes, varType))
1431 return failure();
1432 }
1433
1434 acc::TerminatorOp::create(builder, loc);
1435 return success();
1436}
1437
1438} // namespace
1439
1440//===----------------------------------------------------------------------===//
1441// PrivateRecipeOp
1442//===----------------------------------------------------------------------===//
1443
1445 Operation *op, Region &region, StringRef regionType, StringRef regionName,
1446 Type type, bool verifyYield, bool optional = false) {
1447 if (optional && region.empty())
1448 return success();
1449
1450 if (region.empty())
1451 return op->emitOpError() << "expects non-empty " << regionName << " region";
1452 Block &firstBlock = region.front();
1453 if (firstBlock.getNumArguments() < 1 ||
1454 firstBlock.getArgument(0).getType() != type)
1455 return op->emitOpError() << "expects " << regionName
1456 << " region first "
1457 "argument of the "
1458 << regionType << " type";
1459
1460 if (verifyYield) {
1461 for (YieldOp yieldOp : region.getOps<acc::YieldOp>()) {
1462 if (yieldOp.getOperands().size() != 1 ||
1463 yieldOp.getOperands().getTypes()[0] != type)
1464 return op->emitOpError() << "expects " << regionName
1465 << " region to "
1466 "yield a value of the "
1467 << regionType << " type";
1468 }
1469 }
1470 return success();
1471}
1472
1473LogicalResult acc::PrivateRecipeOp::verifyRegions() {
1474 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1475 "privatization", "init", getType(),
1476 /*verifyYield=*/false)))
1477 return failure();
1479 *this, getDestroyRegion(), "privatization", "destroy", getType(),
1480 /*verifyYield=*/false, /*optional=*/true)))
1481 return failure();
1482 return success();
1483}
1484
1485std::optional<PrivateRecipeOp>
1486PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1487 StringRef recipeName, Type varType,
1488 StringRef varName, ValueRange bounds) {
1489 // First, validate that we can handle this variable type
1490 bool isMappable = isa<MappableType>(varType);
1491 bool isPointerLike = isa<PointerLikeType>(varType);
1492
1493 // Unsupported type
1494 if (!isMappable && !isPointerLike)
1495 return std::nullopt;
1496
1497 OpBuilder::InsertionGuard guard(builder);
1498
1499 // Create the recipe operation first so regions have proper parent context
1500 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1501
1502 // Populate the init region
1503 bool needsFree = false;
1504 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1505 varName, bounds, needsFree))) {
1506 recipe.erase();
1507 return std::nullopt;
1508 }
1509
1510 // Only create destroy region if the allocation needs deallocation
1511 if (needsFree) {
1512 // Extract the allocated value from the init block's yield operation
1513 auto yieldOp =
1514 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1515 Value allocRes = yieldOp.getOperand(0);
1516
1517 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1518 varType, allocRes, bounds))) {
1519 recipe.erase();
1520 return std::nullopt;
1521 }
1522 }
1523
1524 return recipe;
1525}
1526
1527std::optional<PrivateRecipeOp>
1528PrivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1529 StringRef recipeName,
1530 FirstprivateRecipeOp firstprivRecipe) {
1531 // Create the private.recipe op with the same type as the firstprivate.recipe.
1532 OpBuilder::InsertionGuard guard(builder);
1533 auto varType = firstprivRecipe.getType();
1534 auto recipe = PrivateRecipeOp::create(builder, loc, recipeName, varType);
1535
1536 // Clone the init region
1537 IRMapping mapping;
1538 firstprivRecipe.getInitRegion().cloneInto(&recipe.getInitRegion(), mapping);
1539
1540 // Clone destroy region if the firstprivate.recipe has one.
1541 if (!firstprivRecipe.getDestroyRegion().empty()) {
1542 IRMapping mapping;
1543 firstprivRecipe.getDestroyRegion().cloneInto(&recipe.getDestroyRegion(),
1544 mapping);
1545 }
1546 return recipe;
1547}
1548
1549//===----------------------------------------------------------------------===//
1550// FirstprivateRecipeOp
1551//===----------------------------------------------------------------------===//
1552
1553LogicalResult acc::FirstprivateRecipeOp::verifyRegions() {
1554 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(),
1555 "privatization", "init", getType(),
1556 /*verifyYield=*/false)))
1557 return failure();
1558
1559 if (getCopyRegion().empty())
1560 return emitOpError() << "expects non-empty copy region";
1561
1562 Block &firstBlock = getCopyRegion().front();
1563 if (firstBlock.getNumArguments() < 2 ||
1564 firstBlock.getArgument(0).getType() != getType())
1565 return emitOpError() << "expects copy region with two arguments of the "
1566 "privatization type";
1567
1568 if (getDestroyRegion().empty())
1569 return success();
1570
1571 if (failed(verifyInitLikeSingleArgRegion(*this, getDestroyRegion(),
1572 "privatization", "destroy",
1573 getType(), /*verifyYield=*/false)))
1574 return failure();
1575
1576 return success();
1577}
1578
1579std::optional<FirstprivateRecipeOp>
1580FirstprivateRecipeOp::createAndPopulate(OpBuilder &builder, Location loc,
1581 StringRef recipeName, Type varType,
1582 StringRef varName, ValueRange bounds) {
1583 // First, validate that we can handle this variable type
1584 bool isMappable = isa<MappableType>(varType);
1585 bool isPointerLike = isa<PointerLikeType>(varType);
1586
1587 // Unsupported type
1588 if (!isMappable && !isPointerLike)
1589 return std::nullopt;
1590
1591 OpBuilder::InsertionGuard guard(builder);
1592
1593 // Create the recipe operation first so regions have proper parent context
1594 auto recipe = FirstprivateRecipeOp::create(builder, loc, recipeName, varType);
1595
1596 // Populate the init region
1597 bool needsFree = false;
1598 if (failed(createInitRegion(builder, loc, recipe.getInitRegion(), varType,
1599 varName, bounds, needsFree))) {
1600 recipe.erase();
1601 return std::nullopt;
1602 }
1603
1604 // Populate the copy region
1605 if (failed(createCopyRegion(builder, loc, recipe.getCopyRegion(), varType,
1606 bounds))) {
1607 recipe.erase();
1608 return std::nullopt;
1609 }
1610
1611 // Only create destroy region if the allocation needs deallocation
1612 if (needsFree) {
1613 // Extract the allocated value from the init block's yield operation
1614 auto yieldOp =
1615 cast<acc::YieldOp>(recipe.getInitRegion().front().getTerminator());
1616 Value allocRes = yieldOp.getOperand(0);
1617
1618 if (failed(createDestroyRegion(builder, loc, recipe.getDestroyRegion(),
1619 varType, allocRes, bounds))) {
1620 recipe.erase();
1621 return std::nullopt;
1622 }
1623 }
1624
1625 return recipe;
1626}
1627
1628//===----------------------------------------------------------------------===//
1629// ReductionRecipeOp
1630//===----------------------------------------------------------------------===//
1631
1632LogicalResult acc::ReductionRecipeOp::verifyRegions() {
1633 if (failed(verifyInitLikeSingleArgRegion(*this, getInitRegion(), "reduction",
1634 "init", getType(),
1635 /*verifyYield=*/false)))
1636 return failure();
1637
1638 if (getCombinerRegion().empty())
1639 return emitOpError() << "expects non-empty combiner region";
1640
1641 Block &reductionBlock = getCombinerRegion().front();
1642 if (reductionBlock.getNumArguments() < 2 ||
1643 reductionBlock.getArgument(0).getType() != getType() ||
1644 reductionBlock.getArgument(1).getType() != getType())
1645 return emitOpError() << "expects combiner region with the first two "
1646 << "arguments of the reduction type";
1647
1648 for (YieldOp yieldOp : getCombinerRegion().getOps<YieldOp>()) {
1649 if (yieldOp.getOperands().size() != 1 ||
1650 yieldOp.getOperands().getTypes()[0] != getType())
1651 return emitOpError() << "expects combiner region to yield a value "
1652 "of the reduction type";
1653 }
1654
1655 return success();
1656}
1657
1658//===----------------------------------------------------------------------===//
1659// ParallelOp
1660//===----------------------------------------------------------------------===//
1661
1662/// Check dataOperands for acc.parallel, acc.serial and acc.kernels.
1663template <typename Op>
1664static LogicalResult checkDataOperands(Op op,
1665 const mlir::ValueRange &operands) {
1666 for (mlir::Value operand : operands)
1667 if (!mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
1668 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
1669 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
1670 operand.getDefiningOp()))
1671 return op.emitError(
1672 "expect data entry/exit operation or acc.getdeviceptr "
1673 "as defining op");
1674 return success();
1675}
1676
1677template <typename OpT, typename RecipeOpT>
1678static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp,
1679 const mlir::ValueRange &operands,
1680 llvm::StringRef operandName) {
1682 for (mlir::Value operand : operands) {
1683 if (!mlir::isa<OpT>(operand.getDefiningOp()))
1684 return accConstructOp->emitOpError()
1685 << "expected " << operandName << " as defining op";
1686 if (!set.insert(operand).second)
1687 return accConstructOp->emitOpError()
1688 << operandName << " operand appears more than once";
1689 }
1690 return success();
1691}
1692
1693unsigned ParallelOp::getNumDataOperands() {
1694 return getReductionOperands().size() + getPrivateOperands().size() +
1695 getFirstprivateOperands().size() + getDataClauseOperands().size();
1696}
1697
1698Value ParallelOp::getDataOperand(unsigned i) {
1699 unsigned numOptional = getAsyncOperands().size();
1700 numOptional += getNumGangs().size();
1701 numOptional += getNumWorkers().size();
1702 numOptional += getVectorLength().size();
1703 numOptional += getIfCond() ? 1 : 0;
1704 numOptional += getSelfCond() ? 1 : 0;
1705 return getOperand(getWaitOperands().size() + numOptional + i);
1706}
1707
1708template <typename Op>
1709static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands,
1710 ArrayAttr deviceTypes,
1711 llvm::StringRef keyword) {
1712 if (!operands.empty() && deviceTypes.getValue().size() != operands.size())
1713 return op.emitOpError() << keyword << " operands count must match "
1714 << keyword << " device_type count";
1715 return success();
1716}
1717
1718template <typename Op>
1720 Op op, OperandRange operands, DenseI32ArrayAttr segments,
1721 ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment = 0) {
1722 std::size_t numOperandsInSegments = 0;
1723 std::size_t nbOfSegments = 0;
1724
1725 if (segments) {
1726 for (auto segCount : segments.asArrayRef()) {
1727 if (maxInSegment != 0 && segCount > maxInSegment)
1728 return op.emitOpError() << keyword << " expects a maximum of "
1729 << maxInSegment << " values per segment";
1730 numOperandsInSegments += segCount;
1731 ++nbOfSegments;
1732 }
1733 }
1734
1735 if ((numOperandsInSegments != operands.size()) ||
1736 (!deviceTypes && !operands.empty()))
1737 return op.emitOpError()
1738 << keyword << " operand count does not match count in segments";
1739 if (deviceTypes && deviceTypes.getValue().size() != nbOfSegments)
1740 return op.emitOpError()
1741 << keyword << " segment count does not match device_type count";
1742 return success();
1743}
1744
/// Verifies acc.parallel.
///  - private, firstprivate and reduction operands are checked against their
///    recipe ops (checkPrivateOperands).
///  - num_gangs operands are checked against their segments and device_type
///    list, with at most 3 values per segment.
///  - wait operands are checked against their segments and device_type list.
///  - num_workers and vector_length operand counts are checked against their
///    device_type lists (verifyDeviceTypeCountMatch).
///  - async operands are checked against their device_type list.
///  - data clause operands are checked last (checkDataOperands).
/// NOTE(review): the check guarding the bare `return failure();` just before
/// the data-operand check is not fully visible here; confirm what it verifies.
1745LogicalResult acc::ParallelOp::verify() {
1746 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
1747 mlir::acc::PrivateRecipeOp>(
1748 *this, getPrivateOperands(), "private")))
1749 return failure();
1750 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
1751 mlir::acc::FirstprivateRecipeOp>(
1752 *this, getFirstprivateOperands(), "firstprivate")))
1753 return failure();
1754 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
1755 mlir::acc::ReductionRecipeOp>(
1756 *this, getReductionOperands(), "reduction")))
1757 return failure();
1758
1760 *this, getNumGangs(), getNumGangsSegmentsAttr(),
1761 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
1762 return failure();
1763
1765 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
1766 getWaitOperandsDeviceTypeAttr(), "wait")))
1767 return failure();
1768
1769 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
1770 getNumWorkersDeviceTypeAttr(),
1771 "num_workers")))
1772 return failure();
1773
1774 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
1775 getVectorLengthDeviceTypeAttr(),
1776 "vector_length")))
1777 return failure();
1778
1780 getAsyncOperandsDeviceTypeAttr(),
1781 "async")))
1782 return failure();
1783
1785 return failure();
1786
1787 return checkDataOperands<acc::ParallelOp>(*this, getDataClauseOperands());
1788}
1789
/// Looks up `deviceType` in `arrayAttr` with findSegment and returns the value
/// of the operand range at the matching position. A null Value is returned
/// when the attribute is absent or the device type is not found.
1790static mlir::Value
1791getValueInDeviceTypeSegment(std::optional<mlir::ArrayAttr> arrayAttr,
1793 mlir::acc::DeviceType deviceType) {
1794 if (!arrayAttr)
1795 return {};
1796 if (auto pos = findSegment(*arrayAttr, deviceType))
1797 return range[*pos];
1798 return {};
1799}
1800
1801bool acc::ParallelOp::hasAsyncOnly() {
1802 return hasAsyncOnly(mlir::acc::DeviceType::None);
1803}
1804
1805bool acc::ParallelOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
1806 return hasDeviceType(getAsyncOnly(), deviceType);
1807}
1808
1809mlir::Value acc::ParallelOp::getAsyncValue() {
1810 return getAsyncValue(mlir::acc::DeviceType::None);
1811}
1812
/// Returns the async operand associated with `deviceType`. It is looked up from
/// the async operands and, presumably, their device_type list via
/// getValueInDeviceTypeSegment.
1813mlir::Value acc::ParallelOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
1815 getAsyncOperands(), deviceType);
1816}
1817
1818mlir::Value acc::ParallelOp::getNumWorkersValue() {
1819 return getNumWorkersValue(mlir::acc::DeviceType::None);
1820}
1821
/// Returns the num_workers operand whose device_type entry matches
/// `deviceType`, or a null Value when there is none.
1823acc::ParallelOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
1824 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
1825 deviceType);
1826}
1827
1828mlir::Value acc::ParallelOp::getVectorLengthValue() {
1829 return getVectorLengthValue(mlir::acc::DeviceType::None);
1830}
1831
/// Returns the vector_length operand whose device_type entry matches
/// `deviceType`, or a null Value when there is none.
1833acc::ParallelOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
1834 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
1835 getVectorLength(), deviceType);
1836}
1837
1838mlir::Operation::operand_range ParallelOp::getNumGangsValues() {
1839 return getNumGangsValues(mlir::acc::DeviceType::None);
1840}
1841
/// Returns the num_gangs operand segment associated with `deviceType`, selected
/// by getValuesFromSegments from the device_type list and segment sizes.
1843ParallelOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
1844 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
1845 getNumGangsSegments(), deviceType);
1846}
1847
1848bool acc::ParallelOp::hasWaitOnly() {
1849 return hasWaitOnly(mlir::acc::DeviceType::None);
1850}
1851
1852bool acc::ParallelOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
1853 return hasDeviceType(getWaitOnly(), deviceType);
1854}
1855
1856mlir::Operation::operand_range ParallelOp::getWaitValues() {
1857 return getWaitValues(mlir::acc::DeviceType::None);
1858}
1859
/// Returns the wait operands associated with `deviceType`. They are computed
/// from the wait device_type list, the wait operands, their segments and the
/// per-segment devnum flags.
1861ParallelOp::getWaitValues(mlir::acc::DeviceType deviceType) {
1863 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
1864 getHasWaitDevnum(), deviceType);
1865}
1866
1867mlir::Value ParallelOp::getWaitDevnum() {
1868 return getWaitDevnum(mlir::acc::DeviceType::None);
1869}
1870
1871mlir::Value ParallelOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
1872 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
1873 getWaitOperandsSegments(), getHasWaitDevnum(),
1874 deviceType);
1875}
1876
/// Convenience builder for acc.parallel that takes only operand values.
/// Every device_type list, segment attribute, devnum flag list and
/// unit/enum attribute is forwarded as nullptr (unset).
/// NOTE(review): because no device_type/segment attributes are set, non-empty
/// numGangs, waitOperands, numWorkers, vectorLength (and likely asyncOperands)
/// do not match what ParallelOp::verify expects. Confirm callers only pass
/// empty ranges for these, or populate the attributes afterwards.
1877void ParallelOp::build(mlir::OpBuilder &odsBuilder,
1878 mlir::OperationState &odsState,
1879 mlir::ValueRange numGangs, mlir::ValueRange numWorkers,
1880 mlir::ValueRange vectorLength,
1881 mlir::ValueRange asyncOperands,
1882 mlir::ValueRange waitOperands, mlir::Value ifCond,
1883 mlir::Value selfCond, mlir::ValueRange reductionOperands,
1884 mlir::ValueRange gangPrivateOperands,
1885 mlir::ValueRange gangFirstPrivateOperands,
1886 mlir::ValueRange dataClauseOperands) {
1887 ParallelOp::build(
1888 odsBuilder, odsState, asyncOperands, /*asyncOperandsDeviceType=*/nullptr,
1889 /*asyncOnly=*/nullptr, waitOperands, /*waitOperandsSegments=*/nullptr,
1890 /*waitOperandsDeviceType=*/nullptr, /*hasWaitDevnum=*/nullptr,
1891 /*waitOnly=*/nullptr, numGangs, /*numGangsSegments=*/nullptr,
1892 /*numGangsDeviceType=*/nullptr, numWorkers,
1893 /*numWorkersDeviceType=*/nullptr, vectorLength,
1894 /*vectorLengthDeviceType=*/nullptr, ifCond, selfCond,
1895 /*selfAttr=*/nullptr, reductionOperands, gangPrivateOperands,
1896 gangFirstPrivateOperands, dataClauseOperands,
1897 /*defaultAttr=*/nullptr, /*combined=*/nullptr);
1898}
1899
1900void acc::ParallelOp::addNumWorkersOperand(
1901 MLIRContext *context, mlir::Value newValue,
1902 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1903 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1904 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1905 getNumWorkersMutable()));
1906}
1907void acc::ParallelOp::addVectorLengthOperand(
1908 MLIRContext *context, mlir::Value newValue,
1909 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1910 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1911 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1912 getVectorLengthMutable()));
1913}
1914
1915void acc::ParallelOp::addAsyncOnly(
1916 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1917 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
1918 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
1919}
1920
1921void acc::ParallelOp::addAsyncOperand(
1922 MLIRContext *context, mlir::Value newValue,
1923 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1924 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1925 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
1926 getAsyncOperandsMutable()));
1927}
1928
/// Appends `newValues` as a num_gangs group for `effectiveDeviceTypes`.
/// The steps are:
///  1. copy the existing segment sizes into a local vector;
///  2. let addDeviceTypeAffectedOperandHelper update the operands, the
///     device_type list and (presumably) that segment vector;
///  3. store the updated segments back on the op.
1929void acc::ParallelOp::addNumGangsOperands(
1930 MLIRContext *context, mlir::ValueRange newValues,
1931 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1933 if (getNumGangsSegments())
1934 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
1935
1936 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1937 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1938 getNumGangsMutable(), segments));
1939
1940 setNumGangsSegments(segments);
1941}
1942void acc::ParallelOp::addWaitOnly(
1943 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1944 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
1945 effectiveDeviceTypes));
1946}
/// Appends `newValues` as a wait group for `effectiveDeviceTypes`.
/// Segment sizes are round-tripped through a local vector updated by
/// addDeviceTypeAffectedOperandHelper. The hasWaitDevnum list is then extended
/// with max(effectiveDeviceTypes.size(), 1) copies of `hasDevnum`, i.e. one
/// flag per device type added (at least one, for the implicit `none` case).
1947void acc::ParallelOp::addWaitOperands(
1948 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
1949 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
1950
1952 if (getWaitOperandsSegments())
1953 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
1954
1955 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
1956 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
1957 getWaitOperandsMutable(), segments));
1958 setWaitOperandsSegments(segments);
1959
1961 if (getHasWaitDevnumAttr())
1962 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
1963 hasDevnums.insert(
1964 hasDevnums.end(),
1965 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
1966 mlir::BoolAttr::get(context, hasDevnum));
1967 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
1968}
1969
1970void acc::ParallelOp::addPrivatization(MLIRContext *context,
1971 mlir::acc::PrivateOp op,
1972 mlir::acc::PrivateRecipeOp recipe) {
1973 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1974 getPrivateOperandsMutable().append(op.getResult());
1975}
1976
1977void acc::ParallelOp::addFirstPrivatization(
1978 MLIRContext *context, mlir::acc::FirstprivateOp op,
1979 mlir::acc::FirstprivateRecipeOp recipe) {
1980 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1981 getFirstprivateOperandsMutable().append(op.getResult());
1982}
1983
1984void acc::ParallelOp::addReduction(MLIRContext *context,
1985 mlir::acc::ReductionOp op,
1986 mlir::acc::ReductionRecipeOp recipe) {
1987 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
1988 getReductionOperandsMutable().append(op.getResult());
1989}
1990
/// Parses the num_gangs operand list: one or more comma-separated groups
/// `{%v : type, ...}`, each optionally followed by `[<device_type attr>]`.
/// Each group adds its operand count to `segments` and one entry to
/// `deviceTypes`; the entry defaults to DeviceType::None when no `[...]`
/// follows the group.
1991static ParseResult parseNumGangs(
1992 mlir::OpAsmParser &parser,
1994 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
1995 mlir::DenseI32ArrayAttr &segments) {
1998
1999 do {
2000 if (failed(parser.parseLBrace()))
2001 return failure();
2002
2003 int32_t crtOperandsSize = operands.size();
2004 if (failed(parser.parseCommaSeparatedList(
2006 if (parser.parseOperand(operands.emplace_back()) ||
2007 parser.parseColonType(types.emplace_back()))
2008 return failure();
2009 return success();
2010 })))
2011 return failure();
2012 seg.push_back(operands.size() - crtOperandsSize);
2013
2014 if (failed(parser.parseRBrace()))
2015 return failure();
2016
2017 if (succeeded(parser.parseOptionalLSquare())) {
2018 if (parser.parseAttribute(attributes.emplace_back()) ||
2019 parser.parseRSquare())
2020 return failure();
2021 } else {
2022 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2023 parser.getContext(), mlir::acc::DeviceType::None));
2024 }
2025 } while (succeeded(parser.parseOptionalComma()));
2026
2027 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2028 attributes.end());
2029 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2030 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2031
2032 return success();
2033}
2034
// Prints ` [<attr>]` unless the device type is DeviceType::None, which is the
// implicit default and therefore omitted.
// NOTE(review): the dyn_cast result is dereferenced without a null check. If
// `attr` can be anything other than a DeviceTypeAttr, this crashes; prefer
// `cast` (asserting) or an explicit check.
2036 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
2037 if (deviceTypeAttr.getValue() != mlir::acc::DeviceType::None)
2038 p << " [" << attr << "]";
2039}
2040
// Prints the num_gangs groups as `{%v : type, ...}` followed by an optional
// ` [<device_type>]`, walking the operands in segment order. `types` is not
// used; each type is taken from the operand itself.
// NOTE(review): `deviceTypes` and `segments` are dereferenced unconditionally,
// so this assumes both attributes are present whenever it is called.
2042 mlir::OperandRange operands, mlir::TypeRange types,
2043 std::optional<mlir::ArrayAttr> deviceTypes,
2044 std::optional<mlir::DenseI32ArrayAttr> segments) {
2045 unsigned opIdx = 0;
2046 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2047 p << "{";
2048 llvm::interleaveComma(
2049 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2050 p << operands[opIdx] << " : " << operands[opIdx].getType();
2051 ++opIdx;
2052 });
2053 p << "}";
2054 printSingleDeviceType(p, it.value());
2055 });
2056}
2057
// Parses one or more comma-separated groups `{%v : type, ...}`, each optionally
// followed by `[<device_type attr>]`. Each group's operand count goes into
// `segments`; its device type goes into `deviceTypes`, defaulting to
// DeviceType::None when no `[...]` is present.
2059 mlir::OpAsmParser &parser,
2061 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2062 mlir::DenseI32ArrayAttr &segments) {
2065
2066 do {
2067 if (failed(parser.parseLBrace()))
2068 return failure();
2069
2070 int32_t crtOperandsSize = operands.size();
2071
2072 if (failed(parser.parseCommaSeparatedList(
2074 if (parser.parseOperand(operands.emplace_back()) ||
2075 parser.parseColonType(types.emplace_back()))
2076 return failure();
2077 return success();
2078 })))
2079 return failure();
2080
2081 seg.push_back(operands.size() - crtOperandsSize);
2082
2083 if (failed(parser.parseRBrace()))
2084 return failure();
2085
2086 if (succeeded(parser.parseOptionalLSquare())) {
2087 if (parser.parseAttribute(attributes.emplace_back()) ||
2088 parser.parseRSquare())
2089 return failure();
2090 } else {
2091 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2092 parser.getContext(), mlir::acc::DeviceType::None));
2093 }
2094 } while (succeeded(parser.parseOptionalComma()));
2095
2096 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2097 attributes.end());
2098 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2099 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2100
2101 return success();
2102}
2103
// Prints segmented operands as `{%v : type, ...}` groups, each followed by an
// optional ` [<device_type>]`. Assumes `deviceTypes` and `segments` are present
// (both are dereferenced unconditionally).
2106 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2107 std::optional<mlir::DenseI32ArrayAttr> segments) {
2108 unsigned opIdx = 0;
2109 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2110 p << "{";
2111 llvm::interleaveComma(
2112 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2113 p << operands[opIdx] << " : " << operands[opIdx].getType();
2114 ++opIdx;
2115 });
2116 p << "}";
2117 printSingleDeviceType(p, it.value());
2118 });
2119}
2120
/// Parses the operand part of a wait clause.
///  - No `(`: keyword-only wait. `keywordOnly` is set to [none] and nothing
///    else is parsed.
///  - Otherwise `(` is followed by an optional `[dt, ...]` keyword-only list.
///    If that list is present, a `,` must follow it.
///  - Then come one or more groups `{[devnum:] %v : type, ...}`, each
///    optionally followed by `[dt]` (default DeviceType::None), and finally
///    `)`.
/// Each group records its operand count in `segments`, its device type in
/// `deviceTypes`, and whether it started with `devnum:` in `hasDevNum`.
/// NOTE(review): after a `[...]` list this parser requires `,` plus at least one
/// `{...}` group. printWaitClause, however, emits `([...])` alone when there are
/// no operands and the keyword-only list is not just [none]. Confirm these
/// round-trip.
2121static ParseResult parseWaitClause(
2122 mlir::OpAsmParser &parser,
2124 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2125 mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum,
2126 mlir::ArrayAttr &keywordOnly) {
2127 llvm::SmallVector<mlir::Attribute> deviceTypeAttrs, keywordAttrs, devnum;
2129
2130 bool needCommaBeforeOperands = false;
2131
2132 // Keyword only
2133 if (failed(parser.parseOptionalLParen())) {
2134 keywordAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2135 parser.getContext(), mlir::acc::DeviceType::None));
2136 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2137 return success();
2138 }
2139
2140 // Parse keyword only attributes
2141 if (succeeded(parser.parseOptionalLSquare())) {
2142 if (failed(parser.parseCommaSeparatedList([&]() {
2143 if (parser.parseAttribute(keywordAttrs.emplace_back()))
2144 return failure();
2145 return success();
2146 })))
2147 return failure();
2148 if (parser.parseRSquare())
2149 return failure();
2150 needCommaBeforeOperands = true;
2151 }
2152
2153 if (needCommaBeforeOperands && failed(parser.parseComma()))
2154 return failure();
2155
2156 do {
2157 if (failed(parser.parseLBrace()))
2158 return failure();
2159
2160 int32_t crtOperandsSize = operands.size();
2161
2162 if (succeeded(parser.parseOptionalKeyword("devnum"))) {
2163 if (failed(parser.parseColon()))
2164 return failure();
2165 devnum.push_back(BoolAttr::get(parser.getContext(), true));
2166 } else {
2167 devnum.push_back(BoolAttr::get(parser.getContext(), false));
2168 }
2169
2170 if (failed(parser.parseCommaSeparatedList(
2172 if (parser.parseOperand(operands.emplace_back()) ||
2173 parser.parseColonType(types.emplace_back()))
2174 return failure();
2175 return success();
2176 })))
2177 return failure();
2178
2179 seg.push_back(operands.size() - crtOperandsSize);
2180
2181 if (failed(parser.parseRBrace()))
2182 return failure();
2183
2184 if (succeeded(parser.parseOptionalLSquare())) {
2185 if (parser.parseAttribute(deviceTypeAttrs.emplace_back()) ||
2186 parser.parseRSquare())
2187 return failure();
2188 } else {
2189 deviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
2190 parser.getContext(), mlir::acc::DeviceType::None));
2191 }
2192 } while (succeeded(parser.parseOptionalComma()));
2193
2194 if (failed(parser.parseRParen()))
2195 return failure();
2196
2197 deviceTypes = ArrayAttr::get(parser.getContext(), deviceTypeAttrs);
2198 keywordOnly = ArrayAttr::get(parser.getContext(), keywordAttrs);
2199 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
2200 hasDevNum = ArrayAttr::get(parser.getContext(), devnum);
2201
2202 return success();
2203}
2204
2205static bool hasOnlyDeviceTypeNone(std::optional<mlir::ArrayAttr> attrs) {
2206 if (!hasDeviceTypeValues(attrs))
2207 return false;
2208 if (attrs->size() != 1)
2209 return false;
2210 if (auto deviceTypeAttr =
2211 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*attrs)[0]))
2212 return deviceTypeAttr.getValue() == mlir::acc::DeviceType::None;
2213 return false;
2214}
2215
// Prints the operand part of a wait clause.
//  - Prints nothing for a keyword-only wait, i.e. no operands and a keyword-only
//    list of just [none].
//  - Otherwise prints `(`, the keyword-only device types, a `, ` separator when
//    both lists have values, then each group as `{[devnum: ]%v : type, ...}`
//    followed by an optional ` [<device_type>]`, and finally `)`.
2217 mlir::OperandRange operands, mlir::TypeRange types,
2218 std::optional<mlir::ArrayAttr> deviceTypes,
2219 std::optional<mlir::DenseI32ArrayAttr> segments,
2220 std::optional<mlir::ArrayAttr> hasDevNum,
2221 std::optional<mlir::ArrayAttr> keywordOnly) {
2222
2223 if (operands.begin() == operands.end() && hasOnlyDeviceTypeNone(keywordOnly))
2224 return;
2225
2226 p << "(";
2227
2228 printDeviceTypes(p, keywordOnly);
2229 if (hasDeviceTypeValues(keywordOnly) && hasDeviceTypeValues(deviceTypes))
2230 p << ", ";
2231
2232 if (hasDeviceTypeValues(deviceTypes)) {
2233 unsigned opIdx = 0;
2234 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
2235 p << "{";
2236 auto boolAttr = mlir::dyn_cast<mlir::BoolAttr>((*hasDevNum)[it.index()]);
2237 if (boolAttr && boolAttr.getValue())
2238 p << "devnum: ";
2239 llvm::interleaveComma(
2240 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
2241 p << operands[opIdx] << " : " << operands[opIdx].getType();
2242 ++opIdx;
2243 });
2244 p << "}";
2245 printSingleDeviceType(p, it.value());
2246 });
2247 }
2248
2249 p << ")";
2250}
2251
/// Parses a comma-separated list of `%v : type`, each optionally followed by
/// `[<device_type attr>]`. It collects one device-type entry per operand into
/// `deviceTypes`, defaulting to DeviceType::None.
2252static ParseResult parseDeviceTypeOperands(
2253 mlir::OpAsmParser &parser,
2255 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes) {
2257 if (failed(parser.parseCommaSeparatedList([&]() {
2258 if (parser.parseOperand(operands.emplace_back()) ||
2259 parser.parseColonType(types.emplace_back()))
2260 return failure();
2261 if (succeeded(parser.parseOptionalLSquare())) {
2262 if (parser.parseAttribute(attributes.emplace_back()) ||
2263 parser.parseRSquare())
2264 return failure();
2265 } else {
2266 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2267 parser.getContext(), mlir::acc::DeviceType::None));
2268 }
2269 return success();
2270 })))
2271 return failure();
2272 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2273 attributes.end());
2274 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2275 return success();
2276}
2277
/// Prints `%v : type` for each operand, paired positionally with its
/// device-type entry (printed as ` [<device_type>]` unless it is `none`).
/// Prints nothing when `deviceTypes` holds no values. `types` is unused; each
/// type is read from the operand.
2278static void
2280 mlir::OperandRange operands, mlir::TypeRange types,
2281 std::optional<mlir::ArrayAttr> deviceTypes) {
2282 if (!hasDeviceTypeValues(deviceTypes))
2283 return;
2284 llvm::interleaveComma(llvm::zip(*deviceTypes, operands), p, [&](auto it) {
2285 p << std::get<1>(it) << " : " << std::get<1>(it).getType();
2286 printSingleDeviceType(p, std::get<0>(it));
2287 });
2288}
2289
// Parses a clause that may appear keyword-only or with operands.
//  - No `(`: `keywordOnlyDeviceType` is set to [none].
//  - Otherwise `(` is followed by an optional `[dt, ...]` keyword-only list,
//    which must then be followed by `,`.
//  - Then comes a comma-separated list of `%v : type [dt]?` (default
//    DeviceType::None) and `)`.
// NOTE(review): after a `[...]` list, a comma and at least one operand are
// required. The matching printer emits `([...])` alone when there are no
// operands and the keyword-only list is not just [none]. Confirm these
// round-trip.
2291 mlir::OpAsmParser &parser,
2293 llvm::SmallVectorImpl<Type> &types, mlir::ArrayAttr &deviceTypes,
2294 mlir::ArrayAttr &keywordOnlyDeviceType) {
2295
2296 llvm::SmallVector<mlir::Attribute> keywordOnlyDeviceTypeAttributes;
2297 bool needCommaBeforeOperands = false;
2298
2299 if (failed(parser.parseOptionalLParen())) {
2300 // Keyword only
2301 keywordOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2302 parser.getContext(), mlir::acc::DeviceType::None));
2303 keywordOnlyDeviceType =
2304 ArrayAttr::get(parser.getContext(), keywordOnlyDeviceTypeAttributes);
2305 return success();
2306 }
2307
2308 // Parse keyword only attributes
2309 if (succeeded(parser.parseOptionalLSquare())) {
2310 // Parse keyword only attributes
2311 if (failed(parser.parseCommaSeparatedList([&]() {
2312 if (parser.parseAttribute(
2313 keywordOnlyDeviceTypeAttributes.emplace_back()))
2314 return failure();
2315 return success();
2316 })))
2317 return failure();
2318 if (parser.parseRSquare())
2319 return failure();
2320 needCommaBeforeOperands = true;
2321 }
2322
2323 if (needCommaBeforeOperands && failed(parser.parseComma()))
2324 return failure();
2325
2327 if (failed(parser.parseCommaSeparatedList([&]() {
2328 if (parser.parseOperand(operands.emplace_back()) ||
2329 parser.parseColonType(types.emplace_back()))
2330 return failure();
2331 if (succeeded(parser.parseOptionalLSquare())) {
2332 if (parser.parseAttribute(attributes.emplace_back()) ||
2333 parser.parseRSquare())
2334 return failure();
2335 } else {
2336 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
2337 parser.getContext(), mlir::acc::DeviceType::None));
2338 }
2339 return success();
2340 })))
2341 return failure();
2342
2343 if (failed(parser.parseRParen()))
2344 return failure();
2345
2346 llvm::SmallVector<mlir::Attribute> arrayAttr(attributes.begin(),
2347 attributes.end());
2348 deviceTypes = ArrayAttr::get(parser.getContext(), arrayAttr);
2349 return success();
2350}
2351
// Prints nothing for a bare keyword, i.e. no operands and keyword-only list ==
// [none]. Otherwise prints `(`, the keyword-only device types, `, ` when both
// lists have values, the per-operand `%v : type [dt]` list, and `)`.
2354 mlir::TypeRange types, std::optional<mlir::ArrayAttr> deviceTypes,
2355 std::optional<mlir::ArrayAttr> keywordOnlyDeviceTypes) {
2356
2357 if (operands.begin() == operands.end() &&
2358 hasOnlyDeviceTypeNone(keywordOnlyDeviceTypes)) {
2359 return;
2360 }
2361
2362 p << "(";
2363 printDeviceTypes(p, keywordOnlyDeviceTypes);
2364 if (hasDeviceTypeValues(keywordOnlyDeviceTypes) &&
2365 hasDeviceTypeValues(deviceTypes))
2366 p << ", ";
2367 printDeviceTypeOperands(p, op, operands, types, deviceTypes);
2368 p << ")";
2369}
2370
// Parses either a bare keyword or `(%v : type)`. A bare keyword (no `(`) sets
// `attr` to a UnitAttr and leaves `operand` empty. Otherwise the single operand
// and its type are parsed and `attr` is left untouched.
2372 mlir::OpAsmParser &parser,
2373 std::optional<OpAsmParser::UnresolvedOperand> &operand,
2374 mlir::Type &operandType, mlir::UnitAttr &attr) {
2375 // Keyword only
2376 if (failed(parser.parseOptionalLParen())) {
2377 attr = mlir::UnitAttr::get(parser.getContext());
2378 return success();
2379 }
2380
2382 if (failed(parser.parseOperand(op)))
2383 return failure();
2384 operand = op;
2385 if (failed(parser.parseColon()))
2386 return failure();
2387 if (failed(parser.parseType(operandType)))
2388 return failure();
2389 if (failed(parser.parseRParen()))
2390 return failure();
2391
2392 return success();
2393}
2394
// Prints nothing when the keyword-only UnitAttr is set; otherwise prints
// `(%v : type)`.
// NOTE(review): `*operand` is dereferenced whenever `attr` is unset. If neither
// the attribute nor the operand is present, this reads an empty optional.
// Presumably the assembly format only invokes it when one of them is set;
// confirm.
2396 mlir::Operation *op,
2397 std::optional<mlir::Value> operand,
2398 mlir::Type operandType,
2399 mlir::UnitAttr attr) {
2400 if (attr)
2401 return;
2402
2403 p << "(";
2404 p.printOperand(*operand);
2405 p << " : ";
2406 p.printType(operandType);
2407 p << ")";
2408}
2409
// Parses either a bare keyword (no `(`), which sets `attr` to a UnitAttr, or
// `(%a, %b, ... : t1, t2, ...)`. In the latter form all operands are listed
// first, then all types after a single colon.
2411 mlir::OpAsmParser &parser,
2413 llvm::SmallVectorImpl<Type> &types, mlir::UnitAttr &attr) {
2414 // Keyword only
2415 if (failed(parser.parseOptionalLParen())) {
2416 attr = mlir::UnitAttr::get(parser.getContext());
2417 return success();
2418 }
2419
2420 if (failed(parser.parseCommaSeparatedList([&]() {
2421 if (parser.parseOperand(operands.emplace_back()))
2422 return failure();
2423 return success();
2424 })))
2425 return failure();
2426 if (failed(parser.parseColon()))
2427 return failure();
2428 if (failed(parser.parseCommaSeparatedList([&]() {
2429 if (parser.parseType(types.emplace_back()))
2430 return failure();
2431 return success();
2432 })))
2433 return failure();
2434 if (failed(parser.parseRParen()))
2435 return failure();
2436
2437 return success();
2438}
2439
// Prints nothing when the keyword-only UnitAttr is set; otherwise prints
// `(%a, %b, ... : t1, t2, ...)`.
2441 mlir::Operation *op,
2442 mlir::OperandRange operands,
2443 mlir::TypeRange types,
2444 mlir::UnitAttr attr) {
2445 if (attr)
2446 return;
2447
2448 p << "(";
2449 llvm::interleaveComma(operands, p, [&](auto it) { p << it; });
2450 p << " : ";
2451 llvm::interleaveComma(types, p, [&](auto it) { p << it; });
2452 p << ")";
2453}
2454
/// Parses the combined-construct keyword: `kernels`, `parallel` or `serial`.
/// These map to KernelsLoop, ParallelLoop and SerialLoop respectively. Any
/// other token emits "expected compute construct name" and fails.
2455static ParseResult
2457 mlir::acc::CombinedConstructsTypeAttr &attr) {
2458 if (succeeded(parser.parseOptionalKeyword("kernels"))) {
2459 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2460 parser.getContext(), mlir::acc::CombinedConstructsType::KernelsLoop);
2461 } else if (succeeded(parser.parseOptionalKeyword("parallel"))) {
2462 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2463 parser.getContext(), mlir::acc::CombinedConstructsType::ParallelLoop);
2464 } else if (succeeded(parser.parseOptionalKeyword("serial"))) {
2465 attr = mlir::acc::CombinedConstructsTypeAttr::get(
2466 parser.getContext(), mlir::acc::CombinedConstructsType::SerialLoop);
2467 } else {
2468 parser.emitError(parser.getCurrentLocation(),
2469 "expected compute construct name");
2470 return failure();
2471 }
2472 return success();
2473}
2474
/// Prints the combined-construct keyword (`kernels`, `parallel` or `serial`)
/// for a non-null attribute. A null attribute prints nothing. This is the
/// inverse of parseCombinedConstructsLoop.
2475static void
2477 mlir::acc::CombinedConstructsTypeAttr attr) {
2478 if (attr) {
2479 switch (attr.getValue()) {
2480 case mlir::acc::CombinedConstructsType::KernelsLoop:
2481 p << "kernels";
2482 break;
2483 case mlir::acc::CombinedConstructsType::ParallelLoop:
2484 p << "parallel";
2485 break;
2486 case mlir::acc::CombinedConstructsType::SerialLoop:
2487 p << "serial";
2488 break;
2489 };
2490 }
2491}
2492
2493//===----------------------------------------------------------------------===//
2494// SerialOp
2495//===----------------------------------------------------------------------===//
2496
2497unsigned SerialOp::getNumDataOperands() {
2498 return getReductionOperands().size() + getPrivateOperands().size() +
2499 getFirstprivateOperands().size() + getDataClauseOperands().size();
2500}
2501
2502Value SerialOp::getDataOperand(unsigned i) {
2503 unsigned numOptional = getAsyncOperands().size();
2504 numOptional += getIfCond() ? 1 : 0;
2505 numOptional += getSelfCond() ? 1 : 0;
2506 return getOperand(getWaitOperands().size() + numOptional + i);
2507}
2508
2509bool acc::SerialOp::hasAsyncOnly() {
2510 return hasAsyncOnly(mlir::acc::DeviceType::None);
2511}
2512
2513bool acc::SerialOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2514 return hasDeviceType(getAsyncOnly(), deviceType);
2515}
2516
2517mlir::Value acc::SerialOp::getAsyncValue() {
2518 return getAsyncValue(mlir::acc::DeviceType::None);
2519}
2520
/// Returns the async operand associated with `deviceType`. It is looked up from
/// the async operands and, presumably, their device_type list.
2521mlir::Value acc::SerialOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2523 getAsyncOperands(), deviceType);
2524}
2525
2526bool acc::SerialOp::hasWaitOnly() {
2527 return hasWaitOnly(mlir::acc::DeviceType::None);
2528}
2529
2530bool acc::SerialOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2531 return hasDeviceType(getWaitOnly(), deviceType);
2532}
2533
2534mlir::Operation::operand_range SerialOp::getWaitValues() {
2535 return getWaitValues(mlir::acc::DeviceType::None);
2536}
2537
/// Returns the wait operands associated with `deviceType`. They are computed
/// from the wait device_type list, the wait operands, their segments and the
/// per-segment devnum flags.
2539SerialOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2541 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2542 getHasWaitDevnum(), deviceType);
2543}
2544
2545mlir::Value SerialOp::getWaitDevnum() {
2546 return getWaitDevnum(mlir::acc::DeviceType::None);
2547}
2548
2549mlir::Value SerialOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2550 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2551 getWaitOperandsSegments(), getHasWaitDevnum(),
2552 deviceType);
2553}
2554
/// Verifies acc.serial.
///  - private, firstprivate and reduction operands are checked against their
///    recipe ops.
///  - wait operands are checked against their segments and device_type list.
///  - async operands are checked against their device_type list.
///  - data clause operands are checked last (checkDataOperands).
/// NOTE(review): the check guarding the bare `return failure();` just before
/// the data-operand check is not fully visible here; confirm what it verifies.
2555LogicalResult acc::SerialOp::verify() {
2556 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
2557 mlir::acc::PrivateRecipeOp>(
2558 *this, getPrivateOperands(), "private")))
2559 return failure();
2560 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
2561 mlir::acc::FirstprivateRecipeOp>(
2562 *this, getFirstprivateOperands(), "firstprivate")))
2563 return failure();
2564 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
2565 mlir::acc::ReductionRecipeOp>(
2566 *this, getReductionOperands(), "reduction")))
2567 return failure();
2568
2570 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2571 getWaitOperandsDeviceTypeAttr(), "wait")))
2572 return failure();
2573
2575 getAsyncOperandsDeviceTypeAttr(),
2576 "async")))
2577 return failure();
2578
2580 return failure();
2581
2582 return checkDataOperands<acc::SerialOp>(*this, getDataClauseOperands());
2583}
2584
2585void acc::SerialOp::addAsyncOnly(
2586 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2587 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2588 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2589}
2590
2591void acc::SerialOp::addAsyncOperand(
2592 MLIRContext *context, mlir::Value newValue,
2593 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2594 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2595 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2596 getAsyncOperandsMutable()));
2597}
2598
2599void acc::SerialOp::addWaitOnly(
2600 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2601 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2602 effectiveDeviceTypes));
2603}
/// Appends `newValues` as a wait group for `effectiveDeviceTypes`.
/// Segment sizes are round-tripped through a local vector updated by
/// addDeviceTypeAffectedOperandHelper. The hasWaitDevnum list is then extended
/// with max(effectiveDeviceTypes.size(), 1) copies of `hasDevnum`.
2604void acc::SerialOp::addWaitOperands(
2605 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2606 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2607
2609 if (getWaitOperandsSegments())
2610 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2611
2612 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2613 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2614 getWaitOperandsMutable(), segments));
2615 setWaitOperandsSegments(segments);
2616
2618 if (getHasWaitDevnumAttr())
2619 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2620 hasDevnums.insert(
2621 hasDevnums.end(),
2622 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2623 mlir::BoolAttr::get(context, hasDevnum));
2624 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2625}
2626
2627void acc::SerialOp::addPrivatization(MLIRContext *context,
2628 mlir::acc::PrivateOp op,
2629 mlir::acc::PrivateRecipeOp recipe) {
2630 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2631 getPrivateOperandsMutable().append(op.getResult());
2632}
2633
2634void acc::SerialOp::addFirstPrivatization(
2635 MLIRContext *context, mlir::acc::FirstprivateOp op,
2636 mlir::acc::FirstprivateRecipeOp recipe) {
2637 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2638 getFirstprivateOperandsMutable().append(op.getResult());
2639}
2640
2641void acc::SerialOp::addReduction(MLIRContext *context,
2642 mlir::acc::ReductionOp op,
2643 mlir::acc::ReductionRecipeOp recipe) {
2644 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2645 getReductionOperandsMutable().append(op.getResult());
2646}
2647
2648//===----------------------------------------------------------------------===//
2649// KernelsOp
2650//===----------------------------------------------------------------------===//
2651
2652unsigned KernelsOp::getNumDataOperands() {
2653 return getDataClauseOperands().size();
2654}
2655
2656Value KernelsOp::getDataOperand(unsigned i) {
2657 unsigned numOptional = getAsyncOperands().size();
2658 numOptional += getWaitOperands().size();
2659 numOptional += getNumGangs().size();
2660 numOptional += getNumWorkers().size();
2661 numOptional += getVectorLength().size();
2662 numOptional += getIfCond() ? 1 : 0;
2663 numOptional += getSelfCond() ? 1 : 0;
2664 return getOperand(numOptional + i);
2665}
2666
2667bool acc::KernelsOp::hasAsyncOnly() {
2668 return hasAsyncOnly(mlir::acc::DeviceType::None);
2669}
2670
2671bool acc::KernelsOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
2672 return hasDeviceType(getAsyncOnly(), deviceType);
2673}
2674
2675mlir::Value acc::KernelsOp::getAsyncValue() {
2676 return getAsyncValue(mlir::acc::DeviceType::None);
2677}
2678
/// Returns the async operand associated with `deviceType`. It is looked up from
/// the async operands and, presumably, their device_type list.
2679mlir::Value acc::KernelsOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
2681 getAsyncOperands(), deviceType);
2682}
2683
2684mlir::Value acc::KernelsOp::getNumWorkersValue() {
2685 return getNumWorkersValue(mlir::acc::DeviceType::None);
2686}
2687
/// Returns the num_workers operand whose device_type entry matches
/// `deviceType`, or a null Value when there is none.
2689acc::KernelsOp::getNumWorkersValue(mlir::acc::DeviceType deviceType) {
2690 return getValueInDeviceTypeSegment(getNumWorkersDeviceType(), getNumWorkers(),
2691 deviceType);
2692}
2693
2694mlir::Value acc::KernelsOp::getVectorLengthValue() {
2695 return getVectorLengthValue(mlir::acc::DeviceType::None);
2696}
2697
/// Returns the vector_length operand whose device_type entry matches
/// `deviceType`, or a null Value when there is none.
2699acc::KernelsOp::getVectorLengthValue(mlir::acc::DeviceType deviceType) {
2700 return getValueInDeviceTypeSegment(getVectorLengthDeviceType(),
2701 getVectorLength(), deviceType);
2702}
2703
2704mlir::Operation::operand_range KernelsOp::getNumGangsValues() {
2705 return getNumGangsValues(mlir::acc::DeviceType::None);
2706}
2707
/// Returns the num_gangs operand segment associated with `deviceType`, selected
/// by getValuesFromSegments from the device_type list and segment sizes.
2709KernelsOp::getNumGangsValues(mlir::acc::DeviceType deviceType) {
2710 return getValuesFromSegments(getNumGangsDeviceType(), getNumGangs(),
2711 getNumGangsSegments(), deviceType);
2712}
2713
2714bool acc::KernelsOp::hasWaitOnly() {
2715 return hasWaitOnly(mlir::acc::DeviceType::None);
2716}
2717
2718bool acc::KernelsOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
2719 return hasDeviceType(getWaitOnly(), deviceType);
2720}
2721
2722mlir::Operation::operand_range KernelsOp::getWaitValues() {
2723 return getWaitValues(mlir::acc::DeviceType::None);
2724}
2725
/// Returns the wait operands associated with `deviceType`. They are computed
/// from the wait device_type list, the wait operands, their segments and the
/// per-segment devnum flags.
2727KernelsOp::getWaitValues(mlir::acc::DeviceType deviceType) {
2729 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
2730 getHasWaitDevnum(), deviceType);
2731}
2732
2733mlir::Value KernelsOp::getWaitDevnum() {
2734 return getWaitDevnum(mlir::acc::DeviceType::None);
2735}
2736
2737mlir::Value KernelsOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
2738 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
2739 getWaitOperandsSegments(), getHasWaitDevnum(),
2740 deviceType);
2741}
2742
2743LogicalResult acc::KernelsOp::verify() {
2745 *this, getNumGangs(), getNumGangsSegmentsAttr(),
2746 getNumGangsDeviceTypeAttr(), "num_gangs", 3)))
2747 return failure();
2748
2750 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
2751 getWaitOperandsDeviceTypeAttr(), "wait")))
2752 return failure();
2753
2754 if (failed(verifyDeviceTypeCountMatch(*this, getNumWorkers(),
2755 getNumWorkersDeviceTypeAttr(),
2756 "num_workers")))
2757 return failure();
2758
2759 if (failed(verifyDeviceTypeCountMatch(*this, getVectorLength(),
2760 getVectorLengthDeviceTypeAttr(),
2761 "vector_length")))
2762 return failure();
2763
2765 getAsyncOperandsDeviceTypeAttr(),
2766 "async")))
2767 return failure();
2768
2770 return failure();
2771
2772 return checkDataOperands<acc::KernelsOp>(*this, getDataClauseOperands());
2773}
2774
2775void acc::KernelsOp::addPrivatization(MLIRContext *context,
2776 mlir::acc::PrivateOp op,
2777 mlir::acc::PrivateRecipeOp recipe) {
2778 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2779 getPrivateOperandsMutable().append(op.getResult());
2780}
2781
2782void acc::KernelsOp::addFirstPrivatization(
2783 MLIRContext *context, mlir::acc::FirstprivateOp op,
2784 mlir::acc::FirstprivateRecipeOp recipe) {
2785 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2786 getFirstprivateOperandsMutable().append(op.getResult());
2787}
2788
2789void acc::KernelsOp::addReduction(MLIRContext *context,
2790 mlir::acc::ReductionOp op,
2791 mlir::acc::ReductionRecipeOp recipe) {
2792 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
2793 getReductionOperandsMutable().append(op.getResult());
2794}
2795
2796void acc::KernelsOp::addNumWorkersOperand(
2797 MLIRContext *context, mlir::Value newValue,
2798 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2799 setNumWorkersDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2800 context, getNumWorkersDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2801 getNumWorkersMutable()));
2802}
2803
2804void acc::KernelsOp::addVectorLengthOperand(
2805 MLIRContext *context, mlir::Value newValue,
2806 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2807 setVectorLengthDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2808 context, getVectorLengthDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2809 getVectorLengthMutable()));
2810}
2811void acc::KernelsOp::addAsyncOnly(
2812 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2813 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
2814 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
2815}
2816
2817void acc::KernelsOp::addAsyncOperand(
2818 MLIRContext *context, mlir::Value newValue,
2819 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2820 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2821 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
2822 getAsyncOperandsMutable()));
2823}
2824
2825void acc::KernelsOp::addNumGangsOperands(
2826 MLIRContext *context, mlir::ValueRange newValues,
2827 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2829 if (getNumGangsSegmentsAttr())
2830 llvm::copy(*getNumGangsSegments(), std::back_inserter(segments));
2831
2832 setNumGangsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2833 context, getNumGangsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2834 getNumGangsMutable(), segments));
2835
2836 setNumGangsSegments(segments);
2837}
2838
2839void acc::KernelsOp::addWaitOnly(
2840 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2841 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
2842 effectiveDeviceTypes));
2843}
2844void acc::KernelsOp::addWaitOperands(
2845 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
2846 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
2847
2849 if (getWaitOperandsSegments())
2850 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
2851
2852 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
2853 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
2854 getWaitOperandsMutable(), segments));
2855 setWaitOperandsSegments(segments);
2856
2858 if (getHasWaitDevnumAttr())
2859 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
2860 hasDevnums.insert(
2861 hasDevnums.end(),
2862 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
2863 mlir::BoolAttr::get(context, hasDevnum));
2864 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
2865}
2866
2867//===----------------------------------------------------------------------===//
2868// HostDataOp
2869//===----------------------------------------------------------------------===//
2870
2871LogicalResult acc::HostDataOp::verify() {
2872 if (getDataClauseOperands().empty())
2873 return emitError("at least one operand must appear on the host_data "
2874 "operation");
2875
2876 for (mlir::Value operand : getDataClauseOperands())
2877 if (!mlir::isa<acc::UseDeviceOp>(operand.getDefiningOp()))
2878 return emitError("expect data entry operation as defining op");
2879 return success();
2880}
2881
2882void acc::HostDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
2883 MLIRContext *context) {
2884 results.add<RemoveConstantIfConditionWithRegion<HostDataOp>>(context);
2885}
2886
2887//===----------------------------------------------------------------------===//
2888// KernelEnvironmentOp
2889//===----------------------------------------------------------------------===//
2890
2891void acc::KernelEnvironmentOp::getCanonicalizationPatterns(
2892 RewritePatternSet &results, MLIRContext *context) {
2893 results.add<RemoveEmptyKernelEnvironment>(context);
2894}
2895
2896//===----------------------------------------------------------------------===//
2897// LoopOp
2898//===----------------------------------------------------------------------===//
2899
2900static ParseResult parseGangValue(
2901 OpAsmParser &parser, llvm::StringRef keyword,
2904 llvm::SmallVector<GangArgTypeAttr> &attributes, GangArgTypeAttr gangArgType,
2905 bool &needCommaBetweenValues, bool &newValue) {
2906 if (succeeded(parser.parseOptionalKeyword(keyword))) {
2907 if (parser.parseEqual())
2908 return failure();
2909 if (parser.parseOperand(operands.emplace_back()) ||
2910 parser.parseColonType(types.emplace_back()))
2911 return failure();
2912 attributes.push_back(gangArgType);
2913 needCommaBetweenValues = true;
2914 newValue = true;
2915 }
2916 return success();
2917}
2918
2919static ParseResult parseGangClause(
2920 OpAsmParser &parser,
2922 llvm::SmallVectorImpl<Type> &gangOperandsType, mlir::ArrayAttr &gangArgType,
2923 mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments,
2924 mlir::ArrayAttr &gangOnlyDeviceType) {
2925 llvm::SmallVector<GangArgTypeAttr> gangArgTypeAttributes;
2926 llvm::SmallVector<mlir::Attribute> deviceTypeAttributes;
2927 llvm::SmallVector<mlir::Attribute> gangOnlyDeviceTypeAttributes;
2929 bool needCommaBetweenValues = false;
2930 bool needCommaBeforeOperands = false;
2931
2932 if (failed(parser.parseOptionalLParen())) {
2933 // Gang only keyword
2934 gangOnlyDeviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
2935 parser.getContext(), mlir::acc::DeviceType::None));
2936 gangOnlyDeviceType =
2937 ArrayAttr::get(parser.getContext(), gangOnlyDeviceTypeAttributes);
2938 return success();
2939 }
2940
2941 // Parse gang only attributes
2942 if (succeeded(parser.parseOptionalLSquare())) {
2943 // Parse gang only attributes
2944 if (failed(parser.parseCommaSeparatedList([&]() {
2945 if (parser.parseAttribute(
2946 gangOnlyDeviceTypeAttributes.emplace_back()))
2947 return failure();
2948 return success();
2949 })))
2950 return failure();
2951 if (parser.parseRSquare())
2952 return failure();
2953 needCommaBeforeOperands = true;
2954 }
2955
2956 auto argNum = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2957 mlir::acc::GangArgType::Num);
2958 auto argDim = mlir::acc::GangArgTypeAttr::get(parser.getContext(),
2959 mlir::acc::GangArgType::Dim);
2960 auto argStatic = mlir::acc::GangArgTypeAttr::get(
2961 parser.getContext(), mlir::acc::GangArgType::Static);
2962
2963 do {
2964 if (needCommaBeforeOperands) {
2965 needCommaBeforeOperands = false;
2966 continue;
2967 }
2968
2969 if (failed(parser.parseLBrace()))
2970 return failure();
2971
2972 int32_t crtOperandsSize = gangOperands.size();
2973 while (true) {
2974 bool newValue = false;
2975 bool needValue = false;
2976 if (needCommaBetweenValues) {
2977 if (succeeded(parser.parseOptionalComma()))
2978 needValue = true; // expect a new value after comma.
2979 else
2980 break;
2981 }
2982
2983 if (failed(parseGangValue(parser, LoopOp::getGangNumKeyword(),
2984 gangOperands, gangOperandsType,
2985 gangArgTypeAttributes, argNum,
2986 needCommaBetweenValues, newValue)))
2987 return failure();
2988 if (failed(parseGangValue(parser, LoopOp::getGangDimKeyword(),
2989 gangOperands, gangOperandsType,
2990 gangArgTypeAttributes, argDim,
2991 needCommaBetweenValues, newValue)))
2992 return failure();
2993 if (failed(parseGangValue(parser, LoopOp::getGangStaticKeyword(),
2994 gangOperands, gangOperandsType,
2995 gangArgTypeAttributes, argStatic,
2996 needCommaBetweenValues, newValue)))
2997 return failure();
2998
2999 if (!newValue && needValue) {
3000 parser.emitError(parser.getCurrentLocation(),
3001 "new value expected after comma");
3002 return failure();
3003 }
3004
3005 if (!newValue)
3006 break;
3007 }
3008
3009 if (gangOperands.empty())
3010 return parser.emitError(
3011 parser.getCurrentLocation(),
3012 "expect at least one of num, dim or static values");
3013
3014 if (failed(parser.parseRBrace()))
3015 return failure();
3016
3017 if (succeeded(parser.parseOptionalLSquare())) {
3018 if (parser.parseAttribute(deviceTypeAttributes.emplace_back()) ||
3019 parser.parseRSquare())
3020 return failure();
3021 } else {
3022 deviceTypeAttributes.push_back(mlir::acc::DeviceTypeAttr::get(
3023 parser.getContext(), mlir::acc::DeviceType::None));
3024 }
3025
3026 seg.push_back(gangOperands.size() - crtOperandsSize);
3027
3028 } while (succeeded(parser.parseOptionalComma()));
3029
3030 if (failed(parser.parseRParen()))
3031 return failure();
3032
3033 llvm::SmallVector<mlir::Attribute> arrayAttr(gangArgTypeAttributes.begin(),
3034 gangArgTypeAttributes.end());
3035 gangArgType = ArrayAttr::get(parser.getContext(), arrayAttr);
3036 deviceType = ArrayAttr::get(parser.getContext(), deviceTypeAttributes);
3037
3039 gangOnlyDeviceTypeAttributes.begin(), gangOnlyDeviceTypeAttributes.end());
3040 gangOnlyDeviceType = ArrayAttr::get(parser.getContext(), gangOnlyAttr);
3041
3042 segments = DenseI32ArrayAttr::get(parser.getContext(), seg);
3043 return success();
3044}
3045
3047 mlir::OperandRange operands, mlir::TypeRange types,
3048 std::optional<mlir::ArrayAttr> gangArgTypes,
3049 std::optional<mlir::ArrayAttr> deviceTypes,
3050 std::optional<mlir::DenseI32ArrayAttr> segments,
3051 std::optional<mlir::ArrayAttr> gangOnlyDeviceTypes) {
3052
3053 if (operands.begin() == operands.end() &&
3054 hasOnlyDeviceTypeNone(gangOnlyDeviceTypes)) {
3055 return;
3056 }
3057
3058 p << "(";
3059
3060 printDeviceTypes(p, gangOnlyDeviceTypes);
3061
3062 if (hasDeviceTypeValues(gangOnlyDeviceTypes) &&
3063 hasDeviceTypeValues(deviceTypes))
3064 p << ", ";
3065
3066 if (hasDeviceTypeValues(deviceTypes)) {
3067 unsigned opIdx = 0;
3068 llvm::interleaveComma(llvm::enumerate(*deviceTypes), p, [&](auto it) {
3069 p << "{";
3070 llvm::interleaveComma(
3071 llvm::seq<int32_t>(0, (*segments)[it.index()]), p, [&](auto it) {
3072 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3073 (*gangArgTypes)[opIdx]);
3074 if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Num)
3075 p << LoopOp::getGangNumKeyword();
3076 else if (gangArgTypeAttr.getValue() == mlir::acc::GangArgType::Dim)
3077 p << LoopOp::getGangDimKeyword();
3078 else if (gangArgTypeAttr.getValue() ==
3079 mlir::acc::GangArgType::Static)
3080 p << LoopOp::getGangStaticKeyword();
3081 p << "=" << operands[opIdx] << " : " << operands[opIdx].getType();
3082 ++opIdx;
3083 });
3084 p << "}";
3085 printSingleDeviceType(p, it.value());
3086 });
3087 }
3088 p << ")";
3089}
3090
3092 std::optional<mlir::ArrayAttr> segments,
3093 llvm::SmallSet<mlir::acc::DeviceType, 3> &deviceTypes) {
3094 if (!segments)
3095 return false;
3096 for (auto attr : *segments) {
3097 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3098 if (!deviceTypes.insert(deviceTypeAttr.getValue()).second)
3099 return true;
3100 }
3101 return false;
3102}
3103
3104/// Check for duplicates in the DeviceType array attribute.
3105/// Returns std::nullopt if no duplicates, or the duplicate DeviceType if found.
3106static std::optional<mlir::acc::DeviceType>
3107checkDeviceTypes(mlir::ArrayAttr deviceTypes) {
3108 llvm::SmallSet<mlir::acc::DeviceType, 3> crtDeviceTypes;
3109 if (!deviceTypes)
3110 return std::nullopt;
3111 for (auto attr : deviceTypes) {
3112 auto deviceTypeAttr =
3113 mlir::dyn_cast_or_null<mlir::acc::DeviceTypeAttr>(attr);
3114 if (!deviceTypeAttr)
3115 return mlir::acc::DeviceType::None;
3116 if (!crtDeviceTypes.insert(deviceTypeAttr.getValue()).second)
3117 return deviceTypeAttr.getValue();
3118 }
3119 return std::nullopt;
3120}
3121
3122LogicalResult acc::LoopOp::verify() {
3123 if (getUpperbound().size() != getStep().size())
3124 return emitError() << "number of upperbounds expected to be the same as "
3125 "number of steps";
3126
3127 if (getUpperbound().size() != getLowerbound().size())
3128 return emitError() << "number of upperbounds expected to be the same as "
3129 "number of lowerbounds";
3130
3131 if (!getUpperbound().empty() && getInclusiveUpperbound() &&
3132 (getUpperbound().size() != getInclusiveUpperbound()->size()))
3133 return emitError() << "inclusiveUpperbound size is expected to be the same"
3134 << " as upperbound size";
3135
3136 // Check collapse
3137 if (getCollapseAttr() && !getCollapseDeviceTypeAttr())
3138 return emitOpError() << "collapse device_type attr must be define when"
3139 << " collapse attr is present";
3140
3141 if (getCollapseAttr() && getCollapseDeviceTypeAttr() &&
3142 getCollapseAttr().getValue().size() !=
3143 getCollapseDeviceTypeAttr().getValue().size())
3144 return emitOpError() << "collapse attribute count must match collapse"
3145 << " device_type count";
3146 if (auto duplicateDeviceType = checkDeviceTypes(getCollapseDeviceTypeAttr()))
3147 return emitOpError() << "duplicate device_type `"
3148 << acc::stringifyDeviceType(*duplicateDeviceType)
3149 << "` found in collapseDeviceType attribute";
3150
3151 // Check gang
3152 if (!getGangOperands().empty()) {
3153 if (!getGangOperandsArgType())
3154 return emitOpError() << "gangOperandsArgType attribute must be defined"
3155 << " when gang operands are present";
3156
3157 if (getGangOperands().size() !=
3158 getGangOperandsArgTypeAttr().getValue().size())
3159 return emitOpError() << "gangOperandsArgType attribute count must match"
3160 << " gangOperands count";
3161 }
3162 if (getGangAttr()) {
3163 if (auto duplicateDeviceType = checkDeviceTypes(getGangAttr()))
3164 return emitOpError() << "duplicate device_type `"
3165 << acc::stringifyDeviceType(*duplicateDeviceType)
3166 << "` found in gang attribute";
3167 }
3168
3170 *this, getGangOperands(), getGangOperandsSegmentsAttr(),
3171 getGangOperandsDeviceTypeAttr(), "gang")))
3172 return failure();
3173
3174 // Check worker
3175 if (auto duplicateDeviceType = checkDeviceTypes(getWorkerAttr()))
3176 return emitOpError() << "duplicate device_type `"
3177 << acc::stringifyDeviceType(*duplicateDeviceType)
3178 << "` found in worker attribute";
3179 if (auto duplicateDeviceType =
3180 checkDeviceTypes(getWorkerNumOperandsDeviceTypeAttr()))
3181 return emitOpError() << "duplicate device_type `"
3182 << acc::stringifyDeviceType(*duplicateDeviceType)
3183 << "` found in workerNumOperandsDeviceType attribute";
3184 if (failed(verifyDeviceTypeCountMatch(*this, getWorkerNumOperands(),
3185 getWorkerNumOperandsDeviceTypeAttr(),
3186 "worker")))
3187 return failure();
3188
3189 // Check vector
3190 if (auto duplicateDeviceType = checkDeviceTypes(getVectorAttr()))
3191 return emitOpError() << "duplicate device_type `"
3192 << acc::stringifyDeviceType(*duplicateDeviceType)
3193 << "` found in vector attribute";
3194 if (auto duplicateDeviceType =
3195 checkDeviceTypes(getVectorOperandsDeviceTypeAttr()))
3196 return emitOpError() << "duplicate device_type `"
3197 << acc::stringifyDeviceType(*duplicateDeviceType)
3198 << "` found in vectorOperandsDeviceType attribute";
3199 if (failed(verifyDeviceTypeCountMatch(*this, getVectorOperands(),
3200 getVectorOperandsDeviceTypeAttr(),
3201 "vector")))
3202 return failure();
3203
3205 *this, getTileOperands(), getTileOperandsSegmentsAttr(),
3206 getTileOperandsDeviceTypeAttr(), "tile")))
3207 return failure();
3208
3209 // auto, independent and seq attribute are mutually exclusive.
3210 llvm::SmallSet<mlir::acc::DeviceType, 3> deviceTypes;
3211 if (hasDuplicateDeviceTypes(getAuto_(), deviceTypes) ||
3212 hasDuplicateDeviceTypes(getIndependent(), deviceTypes) ||
3213 hasDuplicateDeviceTypes(getSeq(), deviceTypes)) {
3214 return emitError() << "only one of auto, independent, seq can be present "
3215 "at the same time";
3216 }
3217
3218 // Check that at least one of auto, independent, or seq is present
3219 // for the device-independent default clauses.
3220 auto hasDeviceNone = [](mlir::acc::DeviceTypeAttr attr) -> bool {
3221 return attr.getValue() == mlir::acc::DeviceType::None;
3222 };
3223 bool hasDefaultSeq =
3224 getSeqAttr()
3225 ? llvm::any_of(getSeqAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3226 hasDeviceNone)
3227 : false;
3228 bool hasDefaultIndependent =
3229 getIndependentAttr()
3230 ? llvm::any_of(
3231 getIndependentAttr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3232 hasDeviceNone)
3233 : false;
3234 bool hasDefaultAuto =
3235 getAuto_Attr()
3236 ? llvm::any_of(getAuto_Attr().getAsRange<mlir::acc::DeviceTypeAttr>(),
3237 hasDeviceNone)
3238 : false;
3239 if (!hasDefaultSeq && !hasDefaultIndependent && !hasDefaultAuto) {
3240 return emitError()
3241 << "at least one of auto, independent, seq must be present";
3242 }
3243
3244 // Gang, worker and vector are incompatible with seq.
3245 if (getSeqAttr()) {
3246 for (auto attr : getSeqAttr()) {
3247 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
3248 if (hasVector(deviceTypeAttr.getValue()) ||
3249 getVectorValue(deviceTypeAttr.getValue()) ||
3250 hasWorker(deviceTypeAttr.getValue()) ||
3251 getWorkerValue(deviceTypeAttr.getValue()) ||
3252 hasGang(deviceTypeAttr.getValue()) ||
3253 getGangValue(mlir::acc::GangArgType::Num,
3254 deviceTypeAttr.getValue()) ||
3255 getGangValue(mlir::acc::GangArgType::Dim,
3256 deviceTypeAttr.getValue()) ||
3257 getGangValue(mlir::acc::GangArgType::Static,
3258 deviceTypeAttr.getValue()))
3259 return emitError() << "gang, worker or vector cannot appear with seq";
3260 }
3261 }
3262
3263 if (failed(checkPrivateOperands<mlir::acc::PrivateOp,
3264 mlir::acc::PrivateRecipeOp>(
3265 *this, getPrivateOperands(), "private")))
3266 return failure();
3267
3268 if (failed(checkPrivateOperands<mlir::acc::FirstprivateOp,
3269 mlir::acc::FirstprivateRecipeOp>(
3270 *this, getFirstprivateOperands(), "firstprivate")))
3271 return failure();
3272
3273 if (failed(checkPrivateOperands<mlir::acc::ReductionOp,
3274 mlir::acc::ReductionRecipeOp>(
3275 *this, getReductionOperands(), "reduction")))
3276 return failure();
3277
3278 if (getCombined().has_value() &&
3279 (getCombined().value() != acc::CombinedConstructsType::ParallelLoop &&
3280 getCombined().value() != acc::CombinedConstructsType::KernelsLoop &&
3281 getCombined().value() != acc::CombinedConstructsType::SerialLoop)) {
3282 return emitError("unexpected combined constructs attribute");
3283 }
3284
3285 // Check non-empty body().
3286 if (getRegion().empty())
3287 return emitError("expected non-empty body.");
3288
3289 if (getUnstructured()) {
3290 if (!isContainerLike())
3291 return emitError(
3292 "unstructured acc.loop must not have induction variables");
3293 } else if (isContainerLike()) {
3294 // When it is container-like - it is expected to hold a loop-like operation.
3295 // Obtain the maximum collapse count - we use this to check that there
3296 // are enough loops contained.
3297 uint64_t collapseCount = getCollapseValue().value_or(1);
3298 if (getCollapseAttr()) {
3299 for (auto collapseEntry : getCollapseAttr()) {
3300 auto intAttr = mlir::dyn_cast<IntegerAttr>(collapseEntry);
3301 if (intAttr.getValue().getZExtValue() > collapseCount)
3302 collapseCount = intAttr.getValue().getZExtValue();
3303 }
3304 }
3305
3306 // We want to check that we find enough loop-like operations inside.
3307 // PreOrder walk allows us to walk in a breadth-first manner at each nesting
3308 // level.
3309 mlir::Operation *expectedParent = this->getOperation();
3310 bool foundSibling = false;
3311 getRegion().walk<WalkOrder::PreOrder>([&](mlir::Operation *op) {
3312 if (mlir::isa<mlir::LoopLikeOpInterface>(op)) {
3313 // This effectively checks that we are not looking at a sibling loop.
3314 if (op->getParentOfType<mlir::LoopLikeOpInterface>() !=
3315 expectedParent) {
3316 foundSibling = true;
3318 }
3319
3320 collapseCount--;
3321 expectedParent = op;
3322 }
3323 // We found enough contained loops.
3324 if (collapseCount == 0)
3327 });
3328
3329 if (foundSibling)
3330 return emitError("found sibling loops inside container-like acc.loop");
3331 if (collapseCount != 0)
3332 return emitError("failed to find enough loop-like operations inside "
3333 "container-like acc.loop");
3334 }
3335
3336 return success();
3337}
3338
3339unsigned LoopOp::getNumDataOperands() {
3340 return getReductionOperands().size() + getPrivateOperands().size() +
3341 getFirstprivateOperands().size();
3342}
3343
3344Value LoopOp::getDataOperand(unsigned i) {
3345 unsigned numOptional =
3346 getLowerbound().size() + getUpperbound().size() + getStep().size();
3347 numOptional += getGangOperands().size();
3348 numOptional += getVectorOperands().size();
3349 numOptional += getWorkerNumOperands().size();
3350 numOptional += getTileOperands().size();
3351 numOptional += getCacheOperands().size();
3352 return getOperand(numOptional + i);
3353}
3354
3355bool LoopOp::hasAuto() { return hasAuto(mlir::acc::DeviceType::None); }
3356
3357bool LoopOp::hasAuto(mlir::acc::DeviceType deviceType) {
3358 return hasDeviceType(getAuto_(), deviceType);
3359}
3360
3361bool LoopOp::hasIndependent() {
3362 return hasIndependent(mlir::acc::DeviceType::None);
3363}
3364
3365bool LoopOp::hasIndependent(mlir::acc::DeviceType deviceType) {
3366 return hasDeviceType(getIndependent(), deviceType);
3367}
3368
3369bool LoopOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
3370
3371bool LoopOp::hasSeq(mlir::acc::DeviceType deviceType) {
3372 return hasDeviceType(getSeq(), deviceType);
3373}
3374
3375mlir::Value LoopOp::getVectorValue() {
3376 return getVectorValue(mlir::acc::DeviceType::None);
3377}
3378
3379mlir::Value LoopOp::getVectorValue(mlir::acc::DeviceType deviceType) {
3380 return getValueInDeviceTypeSegment(getVectorOperandsDeviceType(),
3381 getVectorOperands(), deviceType);
3382}
3383
3384bool LoopOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
3385
3386bool LoopOp::hasVector(mlir::acc::DeviceType deviceType) {
3387 return hasDeviceType(getVector(), deviceType);
3388}
3389
3390mlir::Value LoopOp::getWorkerValue() {
3391 return getWorkerValue(mlir::acc::DeviceType::None);
3392}
3393
3394mlir::Value LoopOp::getWorkerValue(mlir::acc::DeviceType deviceType) {
3395 return getValueInDeviceTypeSegment(getWorkerNumOperandsDeviceType(),
3396 getWorkerNumOperands(), deviceType);
3397}
3398
3399bool LoopOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
3400
3401bool LoopOp::hasWorker(mlir::acc::DeviceType deviceType) {
3402 return hasDeviceType(getWorker(), deviceType);
3403}
3404
3405mlir::Operation::operand_range LoopOp::getTileValues() {
3406 return getTileValues(mlir::acc::DeviceType::None);
3407}
3408
3410LoopOp::getTileValues(mlir::acc::DeviceType deviceType) {
3411 return getValuesFromSegments(getTileOperandsDeviceType(), getTileOperands(),
3412 getTileOperandsSegments(), deviceType);
3413}
3414
3415std::optional<int64_t> LoopOp::getCollapseValue() {
3416 return getCollapseValue(mlir::acc::DeviceType::None);
3417}
3418
3419std::optional<int64_t>
3420LoopOp::getCollapseValue(mlir::acc::DeviceType deviceType) {
3421 if (!getCollapseAttr())
3422 return std::nullopt;
3423 if (auto pos = findSegment(getCollapseDeviceTypeAttr(), deviceType)) {
3424 auto intAttr =
3425 mlir::dyn_cast<IntegerAttr>(getCollapseAttr().getValue()[*pos]);
3426 return intAttr.getValue().getZExtValue();
3427 }
3428 return std::nullopt;
3429}
3430
3431mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType) {
3432 return getGangValue(gangArgType, mlir::acc::DeviceType::None);
3433}
3434
3435mlir::Value LoopOp::getGangValue(mlir::acc::GangArgType gangArgType,
3436 mlir::acc::DeviceType deviceType) {
3437 if (getGangOperands().empty())
3438 return {};
3439 if (auto pos = findSegment(*getGangOperandsDeviceType(), deviceType)) {
3440 int32_t nbOperandsBefore = 0;
3441 for (unsigned i = 0; i < *pos; ++i)
3442 nbOperandsBefore += (*getGangOperandsSegments())[i];
3444 getGangOperands()
3445 .drop_front(nbOperandsBefore)
3446 .take_front((*getGangOperandsSegments())[*pos]);
3447
3448 int32_t argTypeIdx = nbOperandsBefore;
3449 for (auto value : values) {
3450 auto gangArgTypeAttr = mlir::dyn_cast<mlir::acc::GangArgTypeAttr>(
3451 (*getGangOperandsArgType())[argTypeIdx]);
3452 if (gangArgTypeAttr.getValue() == gangArgType)
3453 return value;
3454 ++argTypeIdx;
3455 }
3456 }
3457 return {};
3458}
3459
3460bool LoopOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
3461
3462bool LoopOp::hasGang(mlir::acc::DeviceType deviceType) {
3463 return hasDeviceType(getGang(), deviceType);
3464}
3465
3466llvm::SmallVector<mlir::Region *> acc::LoopOp::getLoopRegions() {
3467 return {&getRegion()};
3468}
3469
3470/// loop-control ::= `control` `(` ssa-id-and-type-list `)` `=`
3471/// `(` ssa-id-and-type-list `)` `to` `(` ssa-id-and-type-list `)` `step`
3472/// `(` ssa-id-and-type-list `)`
3473/// region
3474ParseResult
3477 SmallVectorImpl<Type> &lowerboundType,
3479 SmallVectorImpl<Type> &upperboundType,
3481 SmallVectorImpl<Type> &stepType) {
3482
3484 if (succeeded(
3485 parser.parseOptionalKeyword(acc::LoopOp::getControlKeyword()))) {
3486 if (parser.parseLParen() ||
3487 parser.parseArgumentList(inductionVars, OpAsmParser::Delimiter::None,
3488 /*allowType=*/true) ||
3489 parser.parseRParen() || parser.parseEqual() || parser.parseLParen() ||
3490 parser.parseOperandList(lowerbound, inductionVars.size(),
3492 parser.parseColonTypeList(lowerboundType) || parser.parseRParen() ||
3493 parser.parseKeyword("to") || parser.parseLParen() ||
3494 parser.parseOperandList(upperbound, inductionVars.size(),
3496 parser.parseColonTypeList(upperboundType) || parser.parseRParen() ||
3497 parser.parseKeyword("step") || parser.parseLParen() ||
3498 parser.parseOperandList(step, inductionVars.size(),
3500 parser.parseColonTypeList(stepType) || parser.parseRParen())
3501 return failure();
3502 }
3503 return parser.parseRegion(region, inductionVars);
3504}
3505
3507 ValueRange lowerbound, TypeRange lowerboundType,
3508 ValueRange upperbound, TypeRange upperboundType,
3509 ValueRange steps, TypeRange stepType) {
3510 ValueRange regionArgs = region.front().getArguments();
3511 if (!regionArgs.empty()) {
3512 p << acc::LoopOp::getControlKeyword() << "(";
3513 llvm::interleaveComma(regionArgs, p,
3514 [&p](Value v) { p << v << " : " << v.getType(); });
3515 p << ") = (" << lowerbound << " : " << lowerboundType << ") to ("
3516 << upperbound << " : " << upperboundType << ") " << " step (" << steps
3517 << " : " << stepType << ") ";
3518 }
3519 p.printRegion(region, /*printEntryBlockArgs=*/false);
3520}
3521
3522void acc::LoopOp::addSeq(MLIRContext *context,
3523 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3524 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
3525 effectiveDeviceTypes));
3526}
3527
3528void acc::LoopOp::addIndependent(
3529 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3530 setIndependentAttr(addDeviceTypeAffectedOperandHelper(
3531 context, getIndependentAttr(), effectiveDeviceTypes));
3532}
3533
3534void acc::LoopOp::addAuto(MLIRContext *context,
3535 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3536 setAuto_Attr(addDeviceTypeAffectedOperandHelper(context, getAuto_Attr(),
3537 effectiveDeviceTypes));
3538}
3539
3540void acc::LoopOp::setCollapseForDeviceTypes(
3541 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3542 llvm::APInt value) {
3545
3546 assert((getCollapseAttr() == nullptr) ==
3547 (getCollapseDeviceTypeAttr() == nullptr));
3548 assert(value.getBitWidth() == 64);
3549
3550 if (getCollapseAttr()) {
3551 for (const auto &existing :
3552 llvm::zip_equal(getCollapseAttr(), getCollapseDeviceTypeAttr())) {
3553 newValues.push_back(std::get<0>(existing));
3554 newDeviceTypes.push_back(std::get<1>(existing));
3555 }
3556 }
3557
3558 if (effectiveDeviceTypes.empty()) {
3559 // If the effective device-types list is empty, this is before there are any
3560 // being applied by device_type, so this should be added as a 'none'.
3561 newValues.push_back(
3562 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3563 newDeviceTypes.push_back(
3564 acc::DeviceTypeAttr::get(context, DeviceType::None));
3565 } else {
3566 for (DeviceType dt : effectiveDeviceTypes) {
3567 newValues.push_back(
3568 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), value));
3569 newDeviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
3570 }
3571 }
3572
3573 setCollapseAttr(ArrayAttr::get(context, newValues));
3574 setCollapseDeviceTypeAttr(ArrayAttr::get(context, newDeviceTypes));
3575}
3576
3577void acc::LoopOp::setTileForDeviceTypes(
3578 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3579 ValueRange values) {
3581 if (getTileOperandsSegments())
3582 llvm::copy(*getTileOperandsSegments(), std::back_inserter(segments));
3583
3584 setTileOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3585 context, getTileOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3586 getTileOperandsMutable(), segments));
3587
3588 setTileOperandsSegments(segments);
3589}
3590
3591void acc::LoopOp::addVectorOperand(
3592 MLIRContext *context, mlir::Value newValue,
3593 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3594 setVectorOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3595 context, getVectorOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3596 newValue, getVectorOperandsMutable()));
3597}
3598
3599void acc::LoopOp::addEmptyVector(
3600 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3601 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
3602 effectiveDeviceTypes));
3603}
3604
3605void acc::LoopOp::addWorkerNumOperand(
3606 MLIRContext *context, mlir::Value newValue,
3607 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3608 setWorkerNumOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3609 context, getWorkerNumOperandsDeviceTypeAttr(), effectiveDeviceTypes,
3610 newValue, getWorkerNumOperandsMutable()));
3611}
3612
3613void acc::LoopOp::addEmptyWorker(
3614 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3615 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
3616 effectiveDeviceTypes));
3617}
3618
3619void acc::LoopOp::addEmptyGang(
3620 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3621 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
3622 effectiveDeviceTypes));
3623}
3624
3625bool acc::LoopOp::hasParallelismFlag(DeviceType dt) {
3626 auto hasDevice = [=](DeviceTypeAttr attr) -> bool {
3627 return attr.getValue() == dt;
3628 };
3629 auto testFromArr = [=](ArrayAttr arr) -> bool {
3630 return llvm::any_of(arr.getAsRange<DeviceTypeAttr>(), hasDevice);
3631 };
3632
3633 if (ArrayAttr arr = getSeqAttr(); arr && testFromArr(arr))
3634 return true;
3635 if (ArrayAttr arr = getIndependentAttr(); arr && testFromArr(arr))
3636 return true;
3637 if (ArrayAttr arr = getAuto_Attr(); arr && testFromArr(arr))
3638 return true;
3639
3640 return false;
3641}
3642
3643bool acc::LoopOp::hasDefaultGangWorkerVector() {
3644 return hasVector() || getVectorValue() || hasWorker() || getWorkerValue() ||
3645 hasGang() || getGangValue(GangArgType::Num) ||
3646 getGangValue(GangArgType::Dim) || getGangValue(GangArgType::Static);
3647}
3648
3649acc::LoopParMode
3650acc::LoopOp::getDefaultOrDeviceTypeParallelism(DeviceType deviceType) {
3651 if (hasSeq(deviceType))
3652 return LoopParMode::loop_seq;
3653 if (hasAuto(deviceType))
3654 return LoopParMode::loop_auto;
3655 if (hasIndependent(deviceType))
3656 return LoopParMode::loop_independent;
3657 if (hasSeq())
3658 return LoopParMode::loop_seq;
3659 if (hasAuto())
3660 return LoopParMode::loop_auto;
3661 assert(hasIndependent() &&
3662 "loop must have default auto, seq, or independent");
3663 return LoopParMode::loop_independent;
3664}
3665
3666void acc::LoopOp::addGangOperands(
3667 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
3670 if (std::optional<ArrayRef<int32_t>> existingSegments =
3671 getGangOperandsSegments())
3672 llvm::copy(*existingSegments, std::back_inserter(segments));
3673
3674 unsigned beforeCount = segments.size();
3675
3676 setGangOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3677 context, getGangOperandsDeviceTypeAttr(), effectiveDeviceTypes, values,
3678 getGangOperandsMutable(), segments));
3679
3680 setGangOperandsSegments(segments);
3681
3682 // This is a bit of extra work to make sure we update the 'types' correctly by
3683 // adding to the types collection the correct number of times. We could
3684 // potentially add something similar to the
3685 // addDeviceTypeAffectedOperandHelper, but it seems that would be pretty
3686 // excessive for a one-off case.
3687 unsigned numAdded = segments.size() - beforeCount;
3688
3689 if (numAdded > 0) {
3691 if (getGangOperandsArgTypeAttr())
3692 llvm::copy(getGangOperandsArgTypeAttr(), std::back_inserter(gangTypes));
3693
3694 for (auto i : llvm::index_range(0u, numAdded)) {
3695 llvm::transform(argTypes, std::back_inserter(gangTypes),
3696 [=](mlir::acc::GangArgType gangTy) {
3697 return mlir::acc::GangArgTypeAttr::get(context, gangTy);
3698 });
3699 (void)i;
3700 }
3701
3702 setGangOperandsArgTypeAttr(mlir::ArrayAttr::get(context, gangTypes));
3703 }
3704}
3705
3706void acc::LoopOp::addPrivatization(MLIRContext *context,
3707 mlir::acc::PrivateOp op,
3708 mlir::acc::PrivateRecipeOp recipe) {
3709 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3710 getPrivateOperandsMutable().append(op.getResult());
3711}
3712
3713void acc::LoopOp::addFirstPrivatization(
3714 MLIRContext *context, mlir::acc::FirstprivateOp op,
3715 mlir::acc::FirstprivateRecipeOp recipe) {
3716 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3717 getFirstprivateOperandsMutable().append(op.getResult());
3718}
3719
3720void acc::LoopOp::addReduction(MLIRContext *context, mlir::acc::ReductionOp op,
3721 mlir::acc::ReductionRecipeOp recipe) {
3722 op.setRecipeAttr(mlir::SymbolRefAttr::get(context, recipe.getSymName()));
3723 getReductionOperandsMutable().append(op.getResult());
3724}
3725
//===----------------------------------------------------------------------===//
// DataOp
//===----------------------------------------------------------------------===//
3729
3730LogicalResult acc::DataOp::verify() {
3731 // 2.6.5. Data Construct restriction
3732 // At least one copy, copyin, copyout, create, no_create, present, deviceptr,
3733 // attach, or default clause must appear on a data construct.
3734 if (getOperands().empty() && !getDefaultAttr())
3735 return emitError("at least one operand or the default attribute "
3736 "must appear on the data operation");
3737
3738 for (mlir::Value operand : getDataClauseOperands())
3739 if (isa<BlockArgument>(operand) ||
3740 !mlir::isa<acc::AttachOp, acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
3741 acc::DeleteOp, acc::DetachOp, acc::DevicePtrOp,
3742 acc::GetDevicePtrOp, acc::NoCreateOp, acc::PresentOp>(
3743 operand.getDefiningOp()))
3744 return emitError("expect data entry/exit operation or acc.getdeviceptr "
3745 "as defining op");
3746
3748 return failure();
3749
3750 return success();
3751}
3752
3753unsigned DataOp::getNumDataOperands() { return getDataClauseOperands().size(); }
3754
3755Value DataOp::getDataOperand(unsigned i) {
3756 unsigned numOptional = getIfCond() ? 1 : 0;
3757 numOptional += getAsyncOperands().size() ? 1 : 0;
3758 numOptional += getWaitOperands().size();
3759 return getOperand(numOptional + i);
3760}
3761
3762bool acc::DataOp::hasAsyncOnly() {
3763 return hasAsyncOnly(mlir::acc::DeviceType::None);
3764}
3765
3766bool acc::DataOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
3767 return hasDeviceType(getAsyncOnly(), deviceType);
3768}
3769
3770mlir::Value DataOp::getAsyncValue() {
3771 return getAsyncValue(mlir::acc::DeviceType::None);
3772}
3773
3774mlir::Value DataOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
3776 getAsyncOperands(), deviceType);
3777}
3778
3779bool DataOp::hasWaitOnly() { return hasWaitOnly(mlir::acc::DeviceType::None); }
3780
3781bool DataOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
3782 return hasDeviceType(getWaitOnly(), deviceType);
3783}
3784
3785mlir::Operation::operand_range DataOp::getWaitValues() {
3786 return getWaitValues(mlir::acc::DeviceType::None);
3787}
3788
3790DataOp::getWaitValues(mlir::acc::DeviceType deviceType) {
3792 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
3793 getHasWaitDevnum(), deviceType);
3794}
3795
3796mlir::Value DataOp::getWaitDevnum() {
3797 return getWaitDevnum(mlir::acc::DeviceType::None);
3798}
3799
3800mlir::Value DataOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
3801 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
3802 getWaitOperandsSegments(), getHasWaitDevnum(),
3803 deviceType);
3804}
3805
3806void acc::DataOp::addAsyncOnly(
3807 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3808 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
3809 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
3810}
3811
3812void acc::DataOp::addAsyncOperand(
3813 MLIRContext *context, mlir::Value newValue,
3814 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3815 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3816 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
3817 getAsyncOperandsMutable()));
3818}
3819
3820void acc::DataOp::addWaitOnly(MLIRContext *context,
3821 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3822 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
3823 effectiveDeviceTypes));
3824}
3825
3826void acc::DataOp::addWaitOperands(
3827 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3828 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3829
3831 if (getWaitOperandsSegments())
3832 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
3833
3834 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
3835 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
3836 getWaitOperandsMutable(), segments));
3837 setWaitOperandsSegments(segments);
3838
3840 if (getHasWaitDevnumAttr())
3841 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
3842 hasDevnums.insert(
3843 hasDevnums.end(),
3844 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
3845 mlir::BoolAttr::get(context, hasDevnum));
3846 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
3847}
3848
//===----------------------------------------------------------------------===//
// ExitDataOp
//===----------------------------------------------------------------------===//
3852
3853LogicalResult acc::ExitDataOp::verify() {
3854 // 2.6.6. Data Exit Directive restriction
3855 // At least one copyout, delete, or detach clause must appear on an exit data
3856 // directive.
3857 if (getDataClauseOperands().empty())
3858 return emitError("at least one operand must be present in dataOperands on "
3859 "the exit data operation");
3860
3861 // The async attribute represent the async clause without value. Therefore the
3862 // attribute and operand cannot appear at the same time.
3863 if (getAsyncOperand() && getAsync())
3864 return emitError("async attribute cannot appear with asyncOperand");
3865
3866 // The wait attribute represent the wait clause without values. Therefore the
3867 // attribute and operands cannot appear at the same time.
3868 if (!getWaitOperands().empty() && getWait())
3869 return emitError("wait attribute cannot appear with waitOperands");
3870
3871 if (getWaitDevnum() && getWaitOperands().empty())
3872 return emitError("wait_devnum cannot appear without waitOperands");
3873
3874 return success();
3875}
3876
3877unsigned ExitDataOp::getNumDataOperands() {
3878 return getDataClauseOperands().size();
3879}
3880
3881Value ExitDataOp::getDataOperand(unsigned i) {
3882 unsigned numOptional = getIfCond() ? 1 : 0;
3883 numOptional += getAsyncOperand() ? 1 : 0;
3884 numOptional += getWaitDevnum() ? 1 : 0;
3885 return getOperand(getWaitOperands().size() + numOptional + i);
3886}
3887
3888void ExitDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3889 MLIRContext *context) {
3890 results.add<RemoveConstantIfCondition<ExitDataOp>>(context);
3891}
3892
3893void ExitDataOp::addAsyncOnly(MLIRContext *context,
3894 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3895 assert(effectiveDeviceTypes.empty());
3896 assert(!getAsyncAttr());
3897 assert(!getAsyncOperand());
3898
3899 setAsyncAttr(mlir::UnitAttr::get(context));
3900}
3901
3902void ExitDataOp::addAsyncOperand(
3903 MLIRContext *context, mlir::Value newValue,
3904 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3905 assert(effectiveDeviceTypes.empty());
3906 assert(!getAsyncAttr());
3907 assert(!getAsyncOperand());
3908
3909 getAsyncOperandMutable().append(newValue);
3910}
3911
3912void ExitDataOp::addWaitOnly(MLIRContext *context,
3913 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3914 assert(effectiveDeviceTypes.empty());
3915 assert(!getWaitAttr());
3916 assert(getWaitOperands().empty());
3917 assert(!getWaitDevnum());
3918
3919 setWaitAttr(mlir::UnitAttr::get(context));
3920}
3921
3922void ExitDataOp::addWaitOperands(
3923 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
3924 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3925 assert(effectiveDeviceTypes.empty());
3926 assert(!getWaitAttr());
3927 assert(getWaitOperands().empty());
3928 assert(!getWaitDevnum());
3929
3930 // if hasDevnum, the first value is the devnum. The 'rest' go into the
3931 // operands list.
3932 if (hasDevnum) {
3933 getWaitDevnumMutable().append(newValues.front());
3934 newValues = newValues.drop_front();
3935 }
3936
3937 getWaitOperandsMutable().append(newValues);
3938}
3939
//===----------------------------------------------------------------------===//
// EnterDataOp
//===----------------------------------------------------------------------===//
3943
3944LogicalResult acc::EnterDataOp::verify() {
3945 // 2.6.6. Data Enter Directive restriction
3946 // At least one copyin, create, or attach clause must appear on an enter data
3947 // directive.
3948 if (getDataClauseOperands().empty())
3949 return emitError("at least one operand must be present in dataOperands on "
3950 "the enter data operation");
3951
3952 // The async attribute represent the async clause without value. Therefore the
3953 // attribute and operand cannot appear at the same time.
3954 if (getAsyncOperand() && getAsync())
3955 return emitError("async attribute cannot appear with asyncOperand");
3956
3957 // The wait attribute represent the wait clause without values. Therefore the
3958 // attribute and operands cannot appear at the same time.
3959 if (!getWaitOperands().empty() && getWait())
3960 return emitError("wait attribute cannot appear with waitOperands");
3961
3962 if (getWaitDevnum() && getWaitOperands().empty())
3963 return emitError("wait_devnum cannot appear without waitOperands");
3964
3965 for (mlir::Value operand : getDataClauseOperands())
3966 if (!mlir::isa<acc::AttachOp, acc::CreateOp, acc::CopyinOp>(
3967 operand.getDefiningOp()))
3968 return emitError("expect data entry operation as defining op");
3969
3970 return success();
3971}
3972
3973unsigned EnterDataOp::getNumDataOperands() {
3974 return getDataClauseOperands().size();
3975}
3976
3977Value EnterDataOp::getDataOperand(unsigned i) {
3978 unsigned numOptional = getIfCond() ? 1 : 0;
3979 numOptional += getAsyncOperand() ? 1 : 0;
3980 numOptional += getWaitDevnum() ? 1 : 0;
3981 return getOperand(getWaitOperands().size() + numOptional + i);
3982}
3983
3984void EnterDataOp::getCanonicalizationPatterns(RewritePatternSet &results,
3985 MLIRContext *context) {
3986 results.add<RemoveConstantIfCondition<EnterDataOp>>(context);
3987}
3988
3989void EnterDataOp::addAsyncOnly(
3990 MLIRContext *context, llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
3991 assert(effectiveDeviceTypes.empty());
3992 assert(!getAsyncAttr());
3993 assert(!getAsyncOperand());
3994
3995 setAsyncAttr(mlir::UnitAttr::get(context));
3996}
3997
3998void EnterDataOp::addAsyncOperand(
3999 MLIRContext *context, mlir::Value newValue,
4000 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4001 assert(effectiveDeviceTypes.empty());
4002 assert(!getAsyncAttr());
4003 assert(!getAsyncOperand());
4004
4005 getAsyncOperandMutable().append(newValue);
4006}
4007
4008void EnterDataOp::addWaitOnly(MLIRContext *context,
4009 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4010 assert(effectiveDeviceTypes.empty());
4011 assert(!getWaitAttr());
4012 assert(getWaitOperands().empty());
4013 assert(!getWaitDevnum());
4014
4015 setWaitAttr(mlir::UnitAttr::get(context));
4016}
4017
4018void EnterDataOp::addWaitOperands(
4019 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4020 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4021 assert(effectiveDeviceTypes.empty());
4022 assert(!getWaitAttr());
4023 assert(getWaitOperands().empty());
4024 assert(!getWaitDevnum());
4025
4026 // if hasDevnum, the first value is the devnum. The 'rest' go into the
4027 // operands list.
4028 if (hasDevnum) {
4029 getWaitDevnumMutable().append(newValues.front());
4030 newValues = newValues.drop_front();
4031 }
4032
4033 getWaitOperandsMutable().append(newValues);
4034}
4035
//===----------------------------------------------------------------------===//
// AtomicReadOp
//===----------------------------------------------------------------------===//
4039
4040LogicalResult AtomicReadOp::verify() { return verifyCommon(); }
4041
//===----------------------------------------------------------------------===//
// AtomicWriteOp
//===----------------------------------------------------------------------===//
4045
4046LogicalResult AtomicWriteOp::verify() { return verifyCommon(); }
4047
//===----------------------------------------------------------------------===//
// AtomicUpdateOp
//===----------------------------------------------------------------------===//
4051
4052LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4053 PatternRewriter &rewriter) {
4054 if (op.isNoOp()) {
4055 rewriter.eraseOp(op);
4056 return success();
4057 }
4058
4059 if (Value writeVal = op.getWriteOpVal()) {
4060 rewriter.replaceOpWithNewOp<AtomicWriteOp>(op, op.getX(), writeVal,
4061 op.getIfCond());
4062 return success();
4063 }
4064
4065 return failure();
4066}
4067
4068LogicalResult AtomicUpdateOp::verify() { return verifyCommon(); }
4069
4070LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4071
//===----------------------------------------------------------------------===//
// AtomicCaptureOp
//===----------------------------------------------------------------------===//
4075
4076AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4077 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4078 return op;
4079 return dyn_cast<AtomicReadOp>(getSecondOp());
4080}
4081
4082AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4083 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4084 return op;
4085 return dyn_cast<AtomicWriteOp>(getSecondOp());
4086}
4087
4088AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4089 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4090 return op;
4091 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4092}
4093
4094LogicalResult AtomicCaptureOp::verifyRegions() { return verifyRegionsCommon(); }
4095
//===----------------------------------------------------------------------===//
// DeclareEnterOp
//===----------------------------------------------------------------------===//
4099
4100template <typename Op>
4101static LogicalResult
4103 bool requireAtLeastOneOperand = true) {
4104 if (operands.empty() && requireAtLeastOneOperand)
4105 return emitError(
4106 op->getLoc(),
4107 "at least one operand must appear on the declare operation");
4108
4109 for (mlir::Value operand : operands) {
4110 if (isa<BlockArgument>(operand) ||
4111 !mlir::isa<acc::CopyinOp, acc::CopyoutOp, acc::CreateOp,
4112 acc::DevicePtrOp, acc::GetDevicePtrOp, acc::PresentOp,
4113 acc::DeclareDeviceResidentOp, acc::DeclareLinkOp>(
4114 operand.getDefiningOp()))
4115 return op.emitError(
4116 "expect valid declare data entry operation or acc.getdeviceptr "
4117 "as defining op");
4118
4119 mlir::Value var{getVar(operand.getDefiningOp())};
4120 assert(var && "declare operands can only be data entry operations which "
4121 "must have var");
4122 (void)var;
4123 std::optional<mlir::acc::DataClause> dataClauseOptional{
4124 getDataClause(operand.getDefiningOp())};
4125 assert(dataClauseOptional.has_value() &&
4126 "declare operands can only be data entry operations which must have "
4127 "dataClause");
4128 (void)dataClauseOptional;
4129 }
4130
4131 return success();
4132}
4133
4134LogicalResult acc::DeclareEnterOp::verify() {
4135 return checkDeclareOperands(*this, this->getDataClauseOperands());
4136}
4137
//===----------------------------------------------------------------------===//
// DeclareExitOp
//===----------------------------------------------------------------------===//
4141
4142LogicalResult acc::DeclareExitOp::verify() {
4143 if (getToken())
4144 return checkDeclareOperands(*this, this->getDataClauseOperands(),
4145 /*requireAtLeastOneOperand=*/false);
4146 return checkDeclareOperands(*this, this->getDataClauseOperands());
4147}
4148
//===----------------------------------------------------------------------===//
// DeclareOp
//===----------------------------------------------------------------------===//
4152
4153LogicalResult acc::DeclareOp::verify() {
4154 return checkDeclareOperands(*this, this->getDataClauseOperands());
4155}
4156
//===----------------------------------------------------------------------===//
// RoutineOp
//===----------------------------------------------------------------------===//
4160
4161static unsigned getParallelismForDeviceType(acc::RoutineOp op,
4162 acc::DeviceType dtype) {
4163 unsigned parallelism = 0;
4164 parallelism += (op.hasGang(dtype) || op.getGangDimValue(dtype)) ? 1 : 0;
4165 parallelism += op.hasWorker(dtype) ? 1 : 0;
4166 parallelism += op.hasVector(dtype) ? 1 : 0;
4167 parallelism += op.hasSeq(dtype) ? 1 : 0;
4168 return parallelism;
4169}
4170
4171LogicalResult acc::RoutineOp::verify() {
4172 unsigned baseParallelism =
4173 getParallelismForDeviceType(*this, acc::DeviceType::None);
4174
4175 if (baseParallelism > 1)
4176 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4177 "be present at the same time";
4178
4179 for (uint32_t dtypeInt = 0; dtypeInt != acc::getMaxEnumValForDeviceType();
4180 ++dtypeInt) {
4181 auto dtype = static_cast<acc::DeviceType>(dtypeInt);
4182 if (dtype == acc::DeviceType::None)
4183 continue;
4184 unsigned parallelism = getParallelismForDeviceType(*this, dtype);
4185
4186 if (parallelism > 1 || (baseParallelism == 1 && parallelism == 1))
4187 return emitError() << "only one of `gang`, `worker`, `vector`, `seq` can "
4188 "be present at the same time for device_type `"
4189 << acc::stringifyDeviceType(dtype) << "`";
4190 }
4191
4192 return success();
4193}
4194
4195static ParseResult parseBindName(OpAsmParser &parser,
4196 mlir::ArrayAttr &bindIdName,
4197 mlir::ArrayAttr &bindStrName,
4198 mlir::ArrayAttr &deviceIdTypes,
4199 mlir::ArrayAttr &deviceStrTypes) {
4200 llvm::SmallVector<mlir::Attribute> bindIdNameAttrs;
4201 llvm::SmallVector<mlir::Attribute> bindStrNameAttrs;
4202 llvm::SmallVector<mlir::Attribute> deviceIdTypeAttrs;
4203 llvm::SmallVector<mlir::Attribute> deviceStrTypeAttrs;
4204
4205 if (failed(parser.parseCommaSeparatedList([&]() {
4206 mlir::Attribute newAttr;
4207 bool isSymbolRefAttr;
4208 auto parseResult = parser.parseAttribute(newAttr);
4209 if (auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(newAttr)) {
4210 bindIdNameAttrs.push_back(symbolRefAttr);
4211 isSymbolRefAttr = true;
4212 } else if (auto stringAttr = dyn_cast<mlir::StringAttr>(newAttr)) {
4213 bindStrNameAttrs.push_back(stringAttr);
4214 isSymbolRefAttr = false;
4215 }
4216 if (parseResult)
4217 return failure();
4218 if (failed(parser.parseOptionalLSquare())) {
4219 if (isSymbolRefAttr) {
4220 deviceIdTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4221 parser.getContext(), mlir::acc::DeviceType::None));
4222 } else {
4223 deviceStrTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4224 parser.getContext(), mlir::acc::DeviceType::None));
4225 }
4226 } else {
4227 if (isSymbolRefAttr) {
4228 if (parser.parseAttribute(deviceIdTypeAttrs.emplace_back()) ||
4229 parser.parseRSquare())
4230 return failure();
4231 } else {
4232 if (parser.parseAttribute(deviceStrTypeAttrs.emplace_back()) ||
4233 parser.parseRSquare())
4234 return failure();
4235 }
4236 }
4237 return success();
4238 })))
4239 return failure();
4240
4241 bindIdName = ArrayAttr::get(parser.getContext(), bindIdNameAttrs);
4242 bindStrName = ArrayAttr::get(parser.getContext(), bindStrNameAttrs);
4243 deviceIdTypes = ArrayAttr::get(parser.getContext(), deviceIdTypeAttrs);
4244 deviceStrTypes = ArrayAttr::get(parser.getContext(), deviceStrTypeAttrs);
4245
4246 return success();
4247}
4248
4250 std::optional<mlir::ArrayAttr> bindIdName,
4251 std::optional<mlir::ArrayAttr> bindStrName,
4252 std::optional<mlir::ArrayAttr> deviceIdTypes,
4253 std::optional<mlir::ArrayAttr> deviceStrTypes) {
4254 // Create combined vectors for all bind names and device types
4257
4258 // Append bindIdName and deviceIdTypes
4259 if (hasDeviceTypeValues(deviceIdTypes)) {
4260 allBindNames.append(bindIdName->begin(), bindIdName->end());
4261 allDeviceTypes.append(deviceIdTypes->begin(), deviceIdTypes->end());
4262 }
4263
4264 // Append bindStrName and deviceStrTypes
4265 if (hasDeviceTypeValues(deviceStrTypes)) {
4266 allBindNames.append(bindStrName->begin(), bindStrName->end());
4267 allDeviceTypes.append(deviceStrTypes->begin(), deviceStrTypes->end());
4268 }
4269
4270 // Print the combined sequence
4271 if (!allBindNames.empty())
4272 llvm::interleaveComma(llvm::zip(allBindNames, allDeviceTypes), p,
4273 [&](const auto &pair) {
4274 p << std::get<0>(pair);
4275 printSingleDeviceType(p, std::get<1>(pair));
4276 });
4277}
4278
4279static ParseResult parseRoutineGangClause(OpAsmParser &parser,
4280 mlir::ArrayAttr &gang,
4281 mlir::ArrayAttr &gangDim,
4282 mlir::ArrayAttr &gangDimDeviceTypes) {
4283
4284 llvm::SmallVector<mlir::Attribute> gangAttrs, gangDimAttrs,
4285 gangDimDeviceTypeAttrs;
4286 bool needCommaBeforeOperands = false;
4287
4288 // Gang keyword only
4289 if (failed(parser.parseOptionalLParen())) {
4290 gangAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4291 parser.getContext(), mlir::acc::DeviceType::None));
4292 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4293 return success();
4294 }
4295
4296 // Parse keyword only attributes
4297 if (succeeded(parser.parseOptionalLSquare())) {
4298 if (failed(parser.parseCommaSeparatedList([&]() {
4299 if (parser.parseAttribute(gangAttrs.emplace_back()))
4300 return failure();
4301 return success();
4302 })))
4303 return failure();
4304 if (parser.parseRSquare())
4305 return failure();
4306 needCommaBeforeOperands = true;
4307 }
4308
4309 if (needCommaBeforeOperands && failed(parser.parseComma()))
4310 return failure();
4311
4312 if (failed(parser.parseCommaSeparatedList([&]() {
4313 if (parser.parseKeyword(acc::RoutineOp::getGangDimKeyword()) ||
4314 parser.parseColon() ||
4315 parser.parseAttribute(gangDimAttrs.emplace_back()))
4316 return failure();
4317 if (succeeded(parser.parseOptionalLSquare())) {
4318 if (parser.parseAttribute(gangDimDeviceTypeAttrs.emplace_back()) ||
4319 parser.parseRSquare())
4320 return failure();
4321 } else {
4322 gangDimDeviceTypeAttrs.push_back(mlir::acc::DeviceTypeAttr::get(
4323 parser.getContext(), mlir::acc::DeviceType::None));
4324 }
4325 return success();
4326 })))
4327 return failure();
4328
4329 if (failed(parser.parseRParen()))
4330 return failure();
4331
4332 gang = ArrayAttr::get(parser.getContext(), gangAttrs);
4333 gangDim = ArrayAttr::get(parser.getContext(), gangDimAttrs);
4334 gangDimDeviceTypes =
4335 ArrayAttr::get(parser.getContext(), gangDimDeviceTypeAttrs);
4336
4337 return success();
4338}
4339
4341 std::optional<mlir::ArrayAttr> gang,
4342 std::optional<mlir::ArrayAttr> gangDim,
4343 std::optional<mlir::ArrayAttr> gangDimDeviceTypes) {
4344
4345 if (!hasDeviceTypeValues(gangDimDeviceTypes) && hasDeviceTypeValues(gang) &&
4346 gang->size() == 1) {
4347 auto deviceTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*gang)[0]);
4348 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4349 return;
4350 }
4351
4352 p << "(";
4353
4354 printDeviceTypes(p, gang);
4355
4356 if (hasDeviceTypeValues(gang) && hasDeviceTypeValues(gangDimDeviceTypes))
4357 p << ", ";
4358
4359 if (hasDeviceTypeValues(gangDimDeviceTypes))
4360 llvm::interleaveComma(llvm::zip(*gangDim, *gangDimDeviceTypes), p,
4361 [&](const auto &pair) {
4362 p << acc::RoutineOp::getGangDimKeyword() << ": ";
4363 p << std::get<0>(pair);
4364 printSingleDeviceType(p, std::get<1>(pair));
4365 });
4366
4367 p << ")";
4368}
4369
4370static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser,
4371 mlir::ArrayAttr &deviceTypes) {
4373 // Keyword only
4374 if (failed(parser.parseOptionalLParen())) {
4375 attributes.push_back(mlir::acc::DeviceTypeAttr::get(
4376 parser.getContext(), mlir::acc::DeviceType::None));
4377 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4378 return success();
4379 }
4380
4381 // Parse device type attributes
4382 if (succeeded(parser.parseOptionalLSquare())) {
4383 if (failed(parser.parseCommaSeparatedList([&]() {
4384 if (parser.parseAttribute(attributes.emplace_back()))
4385 return failure();
4386 return success();
4387 })))
4388 return failure();
4389 if (parser.parseRSquare() || parser.parseRParen())
4390 return failure();
4391 }
4392 deviceTypes = ArrayAttr::get(parser.getContext(), attributes);
4393 return success();
4394}
4395
4396static void
4398 std::optional<mlir::ArrayAttr> deviceTypes) {
4399
4400 if (hasDeviceTypeValues(deviceTypes) && deviceTypes->size() == 1) {
4401 auto deviceTypeAttr =
4402 mlir::dyn_cast<mlir::acc::DeviceTypeAttr>((*deviceTypes)[0]);
4403 if (deviceTypeAttr.getValue() == mlir::acc::DeviceType::None)
4404 return;
4405 }
4406
4407 if (!hasDeviceTypeValues(deviceTypes))
4408 return;
4409
4410 p << "([";
4411 llvm::interleaveComma(*deviceTypes, p, [&](mlir::Attribute attr) {
4412 auto dTypeAttr = mlir::dyn_cast<mlir::acc::DeviceTypeAttr>(attr);
4413 p << dTypeAttr;
4414 });
4415 p << "])";
4416}
4417
4418bool RoutineOp::hasWorker() { return hasWorker(mlir::acc::DeviceType::None); }
4419
4420bool RoutineOp::hasWorker(mlir::acc::DeviceType deviceType) {
4421 return hasDeviceType(getWorker(), deviceType);
4422}
4423
4424bool RoutineOp::hasVector() { return hasVector(mlir::acc::DeviceType::None); }
4425
4426bool RoutineOp::hasVector(mlir::acc::DeviceType deviceType) {
4427 return hasDeviceType(getVector(), deviceType);
4428}
4429
4430bool RoutineOp::hasSeq() { return hasSeq(mlir::acc::DeviceType::None); }
4431
4432bool RoutineOp::hasSeq(mlir::acc::DeviceType deviceType) {
4433 return hasDeviceType(getSeq(), deviceType);
4434}
4435
4436std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4437RoutineOp::getBindNameValue() {
4438 return getBindNameValue(mlir::acc::DeviceType::None);
4439}
4440
4441std::optional<std::variant<mlir::SymbolRefAttr, mlir::StringAttr>>
4442RoutineOp::getBindNameValue(mlir::acc::DeviceType deviceType) {
4443 if (!hasDeviceTypeValues(getBindIdNameDeviceType()) &&
4444 !hasDeviceTypeValues(getBindStrNameDeviceType())) {
4445 return std::nullopt;
4446 }
4447
4448 if (auto pos = findSegment(*getBindIdNameDeviceType(), deviceType)) {
4449 auto attr = (*getBindIdName())[*pos];
4450 auto symbolRefAttr = dyn_cast<mlir::SymbolRefAttr>(attr);
4451 assert(symbolRefAttr && "expected SymbolRef");
4452 return symbolRefAttr;
4453 }
4454
4455 if (auto pos = findSegment(*getBindStrNameDeviceType(), deviceType)) {
4456 auto attr = (*getBindStrName())[*pos];
4457 auto stringAttr = dyn_cast<mlir::StringAttr>(attr);
4458 assert(stringAttr && "expected String");
4459 return stringAttr;
4460 }
4461
4462 return std::nullopt;
4463}
4464
4465bool RoutineOp::hasGang() { return hasGang(mlir::acc::DeviceType::None); }
4466
4467bool RoutineOp::hasGang(mlir::acc::DeviceType deviceType) {
4468 return hasDeviceType(getGang(), deviceType);
4469}
4470
4471std::optional<int64_t> RoutineOp::getGangDimValue() {
4472 return getGangDimValue(mlir::acc::DeviceType::None);
4473}
4474
4475std::optional<int64_t>
4476RoutineOp::getGangDimValue(mlir::acc::DeviceType deviceType) {
4477 if (!hasDeviceTypeValues(getGangDimDeviceType()))
4478 return std::nullopt;
4479 if (auto pos = findSegment(*getGangDimDeviceType(), deviceType)) {
4480 auto intAttr = mlir::dyn_cast<mlir::IntegerAttr>((*getGangDim())[*pos]);
4481 return intAttr.getInt();
4482 }
4483 return std::nullopt;
4484}
4485
4486void RoutineOp::addSeq(MLIRContext *context,
4487 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4488 setSeqAttr(addDeviceTypeAffectedOperandHelper(context, getSeqAttr(),
4489 effectiveDeviceTypes));
4490}
4491
4492void RoutineOp::addVector(MLIRContext *context,
4493 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4494 setVectorAttr(addDeviceTypeAffectedOperandHelper(context, getVectorAttr(),
4495 effectiveDeviceTypes));
4496}
4497
4498void RoutineOp::addWorker(MLIRContext *context,
4499 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4500 setWorkerAttr(addDeviceTypeAffectedOperandHelper(context, getWorkerAttr(),
4501 effectiveDeviceTypes));
4502}
4503
4504void RoutineOp::addGang(MLIRContext *context,
4505 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4506 setGangAttr(addDeviceTypeAffectedOperandHelper(context, getGangAttr(),
4507 effectiveDeviceTypes));
4508}
4509
4510void RoutineOp::addGang(MLIRContext *context,
4511 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4512 uint64_t val) {
4515
4516 if (getGangDimAttr())
4517 llvm::copy(getGangDimAttr(), std::back_inserter(dimValues));
4518 if (getGangDimDeviceTypeAttr())
4519 llvm::copy(getGangDimDeviceTypeAttr(), std::back_inserter(deviceTypes));
4520
4521 assert(dimValues.size() == deviceTypes.size());
4522
4523 if (effectiveDeviceTypes.empty()) {
4524 dimValues.push_back(
4525 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4526 deviceTypes.push_back(
4527 acc::DeviceTypeAttr::get(context, acc::DeviceType::None));
4528 } else {
4529 for (DeviceType dt : effectiveDeviceTypes) {
4530 dimValues.push_back(
4531 mlir::IntegerAttr::get(mlir::IntegerType::get(context, 64), val));
4532 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, dt));
4533 }
4534 }
4535 assert(dimValues.size() == deviceTypes.size());
4536
4537 setGangDimAttr(mlir::ArrayAttr::get(context, dimValues));
4538 setGangDimDeviceTypeAttr(mlir::ArrayAttr::get(context, deviceTypes));
4539}
4540
4541void RoutineOp::addBindStrName(MLIRContext *context,
4542 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4543 mlir::StringAttr val) {
4544 unsigned before = getBindStrNameDeviceTypeAttr()
4545 ? getBindStrNameDeviceTypeAttr().size()
4546 : 0;
4547
4548 setBindStrNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4549 context, getBindStrNameDeviceTypeAttr(), effectiveDeviceTypes));
4550 unsigned after = getBindStrNameDeviceTypeAttr().size();
4551
4553 if (getBindStrNameAttr())
4554 llvm::copy(getBindStrNameAttr(), std::back_inserter(vals));
4555 for (unsigned i = 0; i < after - before; ++i)
4556 vals.push_back(val);
4557
4558 setBindStrNameAttr(mlir::ArrayAttr::get(context, vals));
4559}
4560
4561void RoutineOp::addBindIDName(MLIRContext *context,
4562 llvm::ArrayRef<DeviceType> effectiveDeviceTypes,
4563 mlir::SymbolRefAttr val) {
4564 unsigned before =
4565 getBindIdNameDeviceTypeAttr() ? getBindIdNameDeviceTypeAttr().size() : 0;
4566
4567 setBindIdNameDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4568 context, getBindIdNameDeviceTypeAttr(), effectiveDeviceTypes));
4569 unsigned after = getBindIdNameDeviceTypeAttr().size();
4570
4572 if (getBindIdNameAttr())
4573 llvm::copy(getBindIdNameAttr(), std::back_inserter(vals));
4574 for (unsigned i = 0; i < after - before; ++i)
4575 vals.push_back(val);
4576
4577 setBindIdNameAttr(mlir::ArrayAttr::get(context, vals));
4578}
4579
4580//===----------------------------------------------------------------------===//
4581// InitOp
4582//===----------------------------------------------------------------------===//
4583
4584LogicalResult acc::InitOp::verify() {
4585 Operation *currOp = *this;
4586 while ((currOp = currOp->getParentOp()))
4587 if (isComputeOperation(currOp))
4588 return emitOpError("cannot be nested in a compute operation");
4589 return success();
4590}
4591
4592void acc::InitOp::addDeviceType(MLIRContext *context,
4593 mlir::acc::DeviceType deviceType) {
4595 if (getDeviceTypesAttr())
4596 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4597
4598 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4599 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4600}
4601
4602//===----------------------------------------------------------------------===//
4603// ShutdownOp
4604//===----------------------------------------------------------------------===//
4605
4606LogicalResult acc::ShutdownOp::verify() {
4607 Operation *currOp = *this;
4608 while ((currOp = currOp->getParentOp()))
4609 if (isComputeOperation(currOp))
4610 return emitOpError("cannot be nested in a compute operation");
4611 return success();
4612}
4613
4614void acc::ShutdownOp::addDeviceType(MLIRContext *context,
4615 mlir::acc::DeviceType deviceType) {
4617 if (getDeviceTypesAttr())
4618 llvm::copy(getDeviceTypesAttr(), std::back_inserter(deviceTypes));
4619
4620 deviceTypes.push_back(acc::DeviceTypeAttr::get(context, deviceType));
4621 setDeviceTypesAttr(mlir::ArrayAttr::get(context, deviceTypes));
4622}
4623
4624//===----------------------------------------------------------------------===//
4625// SetOp
4626//===----------------------------------------------------------------------===//
4627
4628LogicalResult acc::SetOp::verify() {
4629 Operation *currOp = *this;
4630 while ((currOp = currOp->getParentOp()))
4631 if (isComputeOperation(currOp))
4632 return emitOpError("cannot be nested in a compute operation");
4633 if (!getDeviceTypeAttr() && !getDefaultAsync() && !getDeviceNum())
4634 return emitOpError("at least one default_async, device_num, or device_type "
4635 "operand must appear");
4636 return success();
4637}
4638
4639//===----------------------------------------------------------------------===//
4640// UpdateOp
4641//===----------------------------------------------------------------------===//
4642
4643LogicalResult acc::UpdateOp::verify() {
4644 // At least one of host or device should have a value.
4645 if (getDataClauseOperands().empty())
4646 return emitError("at least one value must be present in dataOperands");
4647
4649 getAsyncOperandsDeviceTypeAttr(),
4650 "async")))
4651 return failure();
4652
4654 *this, getWaitOperands(), getWaitOperandsSegmentsAttr(),
4655 getWaitOperandsDeviceTypeAttr(), "wait")))
4656 return failure();
4657
4659 return failure();
4660
4661 for (mlir::Value operand : getDataClauseOperands())
4662 if (!mlir::isa<acc::UpdateDeviceOp, acc::UpdateHostOp, acc::GetDevicePtrOp>(
4663 operand.getDefiningOp()))
4664 return emitError("expect data entry/exit operation or acc.getdeviceptr "
4665 "as defining op");
4666
4667 return success();
4668}
4669
4670unsigned UpdateOp::getNumDataOperands() {
4671 return getDataClauseOperands().size();
4672}
4673
4674Value UpdateOp::getDataOperand(unsigned i) {
4675 unsigned numOptional = getAsyncOperands().size();
4676 numOptional += getIfCond() ? 1 : 0;
4677 return getOperand(getWaitOperands().size() + numOptional + i);
4678}
4679
4680void UpdateOp::getCanonicalizationPatterns(RewritePatternSet &results,
4681 MLIRContext *context) {
4682 results.add<RemoveConstantIfCondition<UpdateOp>>(context);
4683}
4684
4685bool UpdateOp::hasAsyncOnly() {
4686 return hasAsyncOnly(mlir::acc::DeviceType::None);
4687}
4688
4689bool UpdateOp::hasAsyncOnly(mlir::acc::DeviceType deviceType) {
4690 return hasDeviceType(getAsyncOnly(), deviceType);
4691}
4692
4693mlir::Value UpdateOp::getAsyncValue() {
4694 return getAsyncValue(mlir::acc::DeviceType::None);
4695}
4696
4697mlir::Value UpdateOp::getAsyncValue(mlir::acc::DeviceType deviceType) {
4699 return {};
4700
4701 if (auto pos = findSegment(*getAsyncOperandsDeviceType(), deviceType))
4702 return getAsyncOperands()[*pos];
4703
4704 return {};
4705}
4706
4707bool UpdateOp::hasWaitOnly() {
4708 return hasWaitOnly(mlir::acc::DeviceType::None);
4709}
4710
4711bool UpdateOp::hasWaitOnly(mlir::acc::DeviceType deviceType) {
4712 return hasDeviceType(getWaitOnly(), deviceType);
4713}
4714
4715mlir::Operation::operand_range UpdateOp::getWaitValues() {
4716 return getWaitValues(mlir::acc::DeviceType::None);
4717}
4718
4720UpdateOp::getWaitValues(mlir::acc::DeviceType deviceType) {
4722 getWaitOperandsDeviceType(), getWaitOperands(), getWaitOperandsSegments(),
4723 getHasWaitDevnum(), deviceType);
4724}
4725
4726mlir::Value UpdateOp::getWaitDevnum() {
4727 return getWaitDevnum(mlir::acc::DeviceType::None);
4728}
4729
4730mlir::Value UpdateOp::getWaitDevnum(mlir::acc::DeviceType deviceType) {
4731 return getWaitDevnumValue(getWaitOperandsDeviceType(), getWaitOperands(),
4732 getWaitOperandsSegments(), getHasWaitDevnum(),
4733 deviceType);
4734}
4735
4736void UpdateOp::addAsyncOnly(MLIRContext *context,
4737 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4738 setAsyncOnlyAttr(addDeviceTypeAffectedOperandHelper(
4739 context, getAsyncOnlyAttr(), effectiveDeviceTypes));
4740}
4741
4742void UpdateOp::addAsyncOperand(
4743 MLIRContext *context, mlir::Value newValue,
4744 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4745 setAsyncOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4746 context, getAsyncOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValue,
4747 getAsyncOperandsMutable()));
4748}
4749
4750void UpdateOp::addWaitOnly(MLIRContext *context,
4751 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4752 setWaitOnlyAttr(addDeviceTypeAffectedOperandHelper(context, getWaitOnlyAttr(),
4753 effectiveDeviceTypes));
4754}
4755
4756void UpdateOp::addWaitOperands(
4757 MLIRContext *context, bool hasDevnum, mlir::ValueRange newValues,
4758 llvm::ArrayRef<DeviceType> effectiveDeviceTypes) {
4759
4761 if (getWaitOperandsSegments())
4762 llvm::copy(*getWaitOperandsSegments(), std::back_inserter(segments));
4763
4764 setWaitOperandsDeviceTypeAttr(addDeviceTypeAffectedOperandHelper(
4765 context, getWaitOperandsDeviceTypeAttr(), effectiveDeviceTypes, newValues,
4766 getWaitOperandsMutable(), segments));
4767 setWaitOperandsSegments(segments);
4768
4770 if (getHasWaitDevnumAttr())
4771 llvm::copy(getHasWaitDevnumAttr(), std::back_inserter(hasDevnums));
4772 hasDevnums.insert(
4773 hasDevnums.end(),
4774 std::max(effectiveDeviceTypes.size(), static_cast<size_t>(1)),
4775 mlir::BoolAttr::get(context, hasDevnum));
4776 setHasWaitDevnumAttr(mlir::ArrayAttr::get(context, hasDevnums));
4777}
4778
4779//===----------------------------------------------------------------------===//
4780// WaitOp
4781//===----------------------------------------------------------------------===//
4782
4783LogicalResult acc::WaitOp::verify() {
4784 // The async attribute represent the async clause without value. Therefore the
4785 // attribute and operand cannot appear at the same time.
4786 if (getAsyncOperand() && getAsync())
4787 return emitError("async attribute cannot appear with asyncOperand");
4788
4789 if (getWaitDevnum() && getWaitOperands().empty())
4790 return emitError("wait_devnum cannot appear without waitOperands");
4791
4792 return success();
4793}
4794
4795#define GET_OP_CLASSES
4796#include "mlir/Dialect/OpenACC/OpenACCOps.cpp.inc"
4797
4798#define GET_ATTRDEF_CLASSES
4799#include "mlir/Dialect/OpenACC/OpenACCOpsAttributes.cpp.inc"
4800
4801#define GET_TYPEDEF_CLASSES
4802#include "mlir/Dialect/OpenACC/OpenACCOpsTypes.cpp.inc"
4803
4804//===----------------------------------------------------------------------===//
4805// acc dialect utilities
4806//===----------------------------------------------------------------------===//
4807
4810 auto varPtr{llvm::TypeSwitch<mlir::Operation *,
4812 accDataClauseOp)
4813 .Case<ACC_DATA_ENTRY_OPS>(
4814 [&](auto entry) { return entry.getVarPtr(); })
4815 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4816 [&](auto exit) { return exit.getVarPtr(); })
4817 .Default([&](mlir::Operation *) {
4819 })};
4820 return varPtr;
4821}
4822
4824 auto varPtr{
4826 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getVar(); })
4827 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4828 return varPtr;
4829}
4830
4832 auto varType{llvm::TypeSwitch<mlir::Operation *, mlir::Type>(accDataClauseOp)
4833 .Case<ACC_DATA_ENTRY_OPS>(
4834 [&](auto entry) { return entry.getVarType(); })
4835 .Case<mlir::acc::CopyoutOp, mlir::acc::UpdateHostOp>(
4836 [&](auto exit) { return exit.getVarType(); })
4837 .Default([&](mlir::Operation *) { return mlir::Type(); })};
4838 return varType;
4839}
4840
4843 auto accPtr{llvm::TypeSwitch<mlir::Operation *,
4845 accDataClauseOp)
4846 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>(
4847 [&](auto dataClause) { return dataClause.getAccPtr(); })
4848 .Default([&](mlir::Operation *) {
4850 })};
4851 return accPtr;
4852}
4853
4855 auto accPtr{llvm::TypeSwitch<mlir::Operation *, mlir::Value>(accDataClauseOp)
4857 [&](auto dataClause) { return dataClause.getAccVar(); })
4858 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4859 return accPtr;
4860}
4861
4863 auto varPtrPtr{
4865 .Case<ACC_DATA_ENTRY_OPS>(
4866 [&](auto dataClause) { return dataClause.getVarPtrPtr(); })
4867 .Default([&](mlir::Operation *) { return mlir::Value(); })};
4868 return varPtrPtr;
4869}
4870
4875 accDataClauseOp)
4876 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4878 dataClause.getBounds().begin(), dataClause.getBounds().end());
4879 })
4880 .Default([&](mlir::Operation *) {
4882 })};
4883 return bounds;
4884}
4885
4889 accDataClauseOp)
4890 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4892 dataClause.getAsyncOperands().begin(),
4893 dataClause.getAsyncOperands().end());
4894 })
4895 .Default([&](mlir::Operation *) {
4897 });
4898}
4899
4900mlir::ArrayAttr
4903 .Case<ACC_DATA_ENTRY_OPS, ACC_DATA_EXIT_OPS>([&](auto dataClause) {
4904 return dataClause.getAsyncOperandsDeviceTypeAttr();
4905 })
4906 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4907}
4908
4909mlir::ArrayAttr mlir::acc::getAsyncOnly(mlir::Operation *accDataClauseOp) {
4912 [&](auto dataClause) { return dataClause.getAsyncOnlyAttr(); })
4913 .Default([&](mlir::Operation *) { return mlir::ArrayAttr{}; });
4914}
4915
4916std::optional<llvm::StringRef> mlir::acc::getVarName(mlir::Operation *accOp) {
4917 auto name{
4919 .Case<ACC_DATA_ENTRY_OPS>([&](auto entry) { return entry.getName(); })
4920 .Default([&](mlir::Operation *) -> std::optional<llvm::StringRef> {
4921 return {};
4922 })};
4923 return name;
4924}
4925
4926std::optional<mlir::acc::DataClause>
4928 auto dataClause{
4930 accDataEntryOp)
4931 .Case<ACC_DATA_ENTRY_OPS>(
4932 [&](auto entry) { return entry.getDataClause(); })
4933 .Default([&](mlir::Operation *) { return std::nullopt; })};
4934 return dataClause;
4935}
4936
4938 auto implicit{llvm::TypeSwitch<mlir::Operation *, bool>(accDataEntryOp)
4939 .Case<ACC_DATA_ENTRY_OPS>(
4940 [&](auto entry) { return entry.getImplicit(); })
4941 .Default([&](mlir::Operation *) { return false; })};
4942 return implicit;
4943}
4944
4946 auto dataOperands{
4949 [&](auto entry) { return entry.getDataClauseOperands(); })
4950 .Default([&](mlir::Operation *) { return mlir::ValueRange(); })};
4951 return dataOperands;
4952}
4953
4956 auto dataOperands{
4959 [&](auto entry) { return entry.getDataClauseOperandsMutable(); })
4960 .Default([&](mlir::Operation *) { return nullptr; })};
4961 return dataOperands;
4962}
4963
4964mlir::SymbolRefAttr mlir::acc::getRecipe(mlir::Operation *accOp) {
4965 auto recipe{
4967 .Case<ACC_DATA_ENTRY_OPS>(
4968 [&](auto entry) { return entry.getRecipeAttr(); })
4969 .Default([&](mlir::Operation *) { return mlir::SymbolRefAttr{}; })};
4970 return recipe;
4971}
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op, Region &region, ValueRange blockArgs={})
Replaces the given op with the contents of the given single-block region, using the operands of the b...
Definition SCF.cpp:137
static Type getElementType(Type type)
Determine the element type of type.
static LogicalResult verifyYield(linalg::YieldOp op, LinalgOp linalgOp)
ArrayAttr()
if(!isCopyOut)
b getContext())
void printRoutineGangClause(OpAsmPrinter &p, Operation *op, std::optional< mlir::ArrayAttr > gang, std::optional< mlir::ArrayAttr > gangDim, std::optional< mlir::ArrayAttr > gangDimDeviceTypes)
Definition OpenACC.cpp:4340
static ParseResult parseRegions(OpAsmParser &parser, OperationState &state, unsigned nRegions=1)
Definition OpenACC.cpp:1163
bool hasDuplicateDeviceTypes(std::optional< mlir::ArrayAttr > segments, llvm::SmallSet< mlir::acc::DeviceType, 3 > &deviceTypes)
Definition OpenACC.cpp:3091
static LogicalResult verifyDeviceTypeCountMatch(Op op, OperandRange operands, ArrayAttr deviceTypes, llvm::StringRef keyword)
Definition OpenACC.cpp:1709
static ParseResult parseBindName(OpAsmParser &parser, mlir::ArrayAttr &bindIdName, mlir::ArrayAttr &bindStrName, mlir::ArrayAttr &deviceIdTypes, mlir::ArrayAttr &deviceStrTypes)
Definition OpenACC.cpp:4195
static void printRecipeSym(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::SymbolRefAttr recipeAttr)
Definition OpenACC.cpp:764
static bool isComputeOperation(Operation *op)
Definition OpenACC.cpp:1177
static mlir::Operation::operand_range getWaitValuesWithoutDevnum(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:544
static bool hasOnlyDeviceTypeNone(std::optional< mlir::ArrayAttr > attrs)
Definition OpenACC.cpp:2205
static ParseResult parseRecipeSym(mlir::OpAsmParser &parser, mlir::SymbolRefAttr &recipeAttr)
Definition OpenACC.cpp:757
static void printAccVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value accVar, mlir::Type accVarType)
Definition OpenACC.cpp:697
static mlir::Value getWaitDevnumValue(std::optional< mlir::ArrayAttr > deviceTypeAttr, mlir::Operation::operand_range operands, std::optional< llvm::ArrayRef< int32_t > > segments, std::optional< mlir::ArrayAttr > hasWaitDevnum, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:528
static void printVar(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Value var)
Definition OpenACC.cpp:666
static void printWaitClause(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > hasDevNum, std::optional< mlir::ArrayAttr > keywordOnly)
Definition OpenACC.cpp:2216
static ParseResult parseWaitClause(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &hasDevNum, mlir::ArrayAttr &keywordOnly)
Definition OpenACC.cpp:2121
static bool hasDeviceTypeValues(std::optional< mlir::ArrayAttr > arrayAttr)
Definition OpenACC.cpp:470
static void printDeviceTypeArrayAttr(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:4397
static ParseResult parseGangValue(OpAsmParser &parser, llvm::StringRef keyword, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, llvm::SmallVector< GangArgTypeAttr > &attributes, GangArgTypeAttr gangArgType, bool &needCommaBetweenValues, bool &newValue)
Definition OpenACC.cpp:2900
static ParseResult parseCombinedConstructsLoop(mlir::OpAsmParser &parser, mlir::acc::CombinedConstructsTypeAttr &attr)
Definition OpenACC.cpp:2456
static std::optional< mlir::acc::DeviceType > checkDeviceTypes(mlir::ArrayAttr deviceTypes)
Check for duplicates in the DeviceType array attribute.
Definition OpenACC.cpp:3107
static LogicalResult checkDeclareOperands(Op &op, const mlir::ValueRange &operands, bool requireAtLeastOneOperand=true)
Definition OpenACC.cpp:4102
static LogicalResult checkVarAndAccVar(Op op)
Definition OpenACC.cpp:604
static ParseResult parseOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2410
static void printDeviceTypes(mlir::OpAsmPrinter &p, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:488
static LogicalResult checkVarAndVarType(Op op)
Definition OpenACC.cpp:586
static LogicalResult checkValidModifier(Op op, acc::DataClauseModifier validModifiers)
Definition OpenACC.cpp:620
ParseResult parseLoopControl(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lowerbound, SmallVectorImpl< Type > &lowerboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &upperbound, SmallVectorImpl< Type > &upperboundType, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &step, SmallVectorImpl< Type > &stepType)
loop-control ::= control ( ssa-id-and-type-list ) = ( ssa-id-and-type-list ) to ( ssa-id-and-type-lis...
Definition OpenACC.cpp:3475
static LogicalResult checkDataOperands(Op op, const mlir::ValueRange &operands)
Check dataOperands for acc.parallel, acc.serial and acc.kernels.
Definition OpenACC.cpp:1664
static ParseResult parseDeviceTypeOperands(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:2252
static mlir::Value getValueInDeviceTypeSegment(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:1791
static LogicalResult checkNoModifier(Op op)
Definition OpenACC.cpp:612
static ParseResult parseAccVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var, mlir::Type &accVarType)
Definition OpenACC.cpp:675
static std::optional< unsigned > findSegment(ArrayAttr segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:499
static mlir::Operation::operand_range getValuesFromSegments(std::optional< mlir::ArrayAttr > arrayAttr, mlir::Operation::operand_range range, std::optional< llvm::ArrayRef< int32_t > > segments, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:512
static ParseResult parseNumGangs(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:1991
static void getSingleRegionOpSuccessorRegions(Operation *op, Region &region, RegionBranchPoint point, SmallVectorImpl< RegionSuccessor > &regions)
Generic helper for single-region OpenACC ops that execute their body once and then return to the pare...
Definition OpenACC.cpp:400
static ParseResult parseVar(mlir::OpAsmParser &parser, OpAsmParser::UnresolvedOperand &var)
Definition OpenACC.cpp:651
void printLoopControl(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lowerbound, TypeRange lowerboundType, ValueRange upperbound, TypeRange upperboundType, ValueRange steps, TypeRange stepType)
Definition OpenACC.cpp:3506
static ParseResult parseDeviceTypeArrayAttr(OpAsmParser &parser, mlir::ArrayAttr &deviceTypes)
Definition OpenACC.cpp:4370
static ParseResult parseRoutineGangClause(OpAsmParser &parser, mlir::ArrayAttr &gang, mlir::ArrayAttr &gangDim, mlir::ArrayAttr &gangDimDeviceTypes)
Definition OpenACC.cpp:4279
static void printDeviceTypeOperandsWithSegment(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2104
static void printDeviceTypeOperands(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes)
Definition OpenACC.cpp:2279
static void printOperandWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::Value > operand, mlir::Type operandType, mlir::UnitAttr attr)
Definition OpenACC.cpp:2395
static ParseResult parseDeviceTypeOperandsWithSegment(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::DenseI32ArrayAttr &segments)
Definition OpenACC.cpp:2058
static ParseResult parseOperandWithKeywordOnly(mlir::OpAsmParser &parser, std::optional< OpAsmParser::UnresolvedOperand > &operand, mlir::Type &operandType, mlir::UnitAttr &attr)
Definition OpenACC.cpp:2371
static void printVarPtrType(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::Type varPtrType, mlir::TypeAttr varTypeAttr)
Definition OpenACC.cpp:738
static ParseResult parseGangClause(OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &gangOperands, llvm::SmallVectorImpl< Type > &gangOperandsType, mlir::ArrayAttr &gangArgType, mlir::ArrayAttr &deviceType, mlir::DenseI32ArrayAttr &segments, mlir::ArrayAttr &gangOnlyDeviceType)
Definition OpenACC.cpp:2919
static LogicalResult verifyInitLikeSingleArgRegion(Operation *op, Region &region, StringRef regionType, StringRef regionName, Type type, bool verifyYield, bool optional=false)
Definition OpenACC.cpp:1444
static void printOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, mlir::UnitAttr attr)
Definition OpenACC.cpp:2440
static void printSingleDeviceType(mlir::OpAsmPrinter &p, mlir::Attribute attr)
Definition OpenACC.cpp:2035
static LogicalResult checkRecipe(OpT op, llvm::StringRef operandName)
Definition OpenACC.cpp:630
static LogicalResult checkPrivateOperands(mlir::Operation *accConstructOp, const mlir::ValueRange &operands, llvm::StringRef operandName)
Definition OpenACC.cpp:1678
static void printDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::ArrayAttr > keywordOnlyDeviceTypes)
Definition OpenACC.cpp:2352
static bool hasDeviceType(std::optional< mlir::ArrayAttr > arrayAttr, mlir::acc::DeviceType deviceType)
Definition OpenACC.cpp:474
void printGangClause(OpAsmPrinter &p, Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > gangArgTypes, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments, std::optional< mlir::ArrayAttr > gangOnlyDeviceTypes)
Definition OpenACC.cpp:3046
static ParseResult parseDeviceTypeOperandsWithKeywordOnly(mlir::OpAsmParser &parser, llvm::SmallVectorImpl< mlir::OpAsmParser::UnresolvedOperand > &operands, llvm::SmallVectorImpl< Type > &types, mlir::ArrayAttr &deviceTypes, mlir::ArrayAttr &keywordOnlyDeviceType)
Definition OpenACC.cpp:2290
static ParseResult parseVarPtrType(mlir::OpAsmParser &parser, mlir::Type &varPtrType, mlir::TypeAttr &varTypeAttr)
Definition OpenACC.cpp:709
static LogicalResult checkWaitAndAsyncConflict(Op op)
Definition OpenACC.cpp:564
static LogicalResult verifyDeviceTypeAndSegmentCountMatch(Op op, OperandRange operands, DenseI32ArrayAttr segments, ArrayAttr deviceTypes, llvm::StringRef keyword, int32_t maxInSegment=0)
Definition OpenACC.cpp:1719
static unsigned getParallelismForDeviceType(acc::RoutineOp op, acc::DeviceType dtype)
Definition OpenACC.cpp:4161
static void printNumGangs(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::OperandRange operands, mlir::TypeRange types, std::optional< mlir::ArrayAttr > deviceTypes, std::optional< mlir::DenseI32ArrayAttr > segments)
Definition OpenACC.cpp:2041
static void printCombinedConstructsLoop(mlir::OpAsmPrinter &p, mlir::Operation *op, mlir::acc::CombinedConstructsTypeAttr attr)
Definition OpenACC.cpp:2476
static void printBindName(mlir::OpAsmPrinter &p, mlir::Operation *op, std::optional< mlir::ArrayAttr > bindIdName, std::optional< mlir::ArrayAttr > bindStrName, std::optional< mlir::ArrayAttr > deviceIdTypes, std::optional< mlir::ArrayAttr > deviceStrTypes)
Definition OpenACC.cpp:4249
#define ACC_COMPUTE_AND_DATA_CONSTRUCT_OPS
Definition OpenACC.h:68
#define ACC_DATA_ENTRY_OPS
Definition OpenACC.h:45
#define ACC_DATA_EXIT_OPS
Definition OpenACC.h:53
false
Parses a map_entries map type from a string format back into its numeric value.
static void genStore(OpBuilder &builder, Location loc, Value val, Value mem, Value idx)
Generates a store with proper index typing and proper value.
static Value genLoad(OpBuilder &builder, Location loc, Value mem, Value idx)
Generates a load with proper index typing.
virtual ParseResult parseLBrace()=0
Parse a { token.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseRSquare()=0
Parse a ] token.
virtual ParseResult parseRBrace()=0
Parse a } token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
virtual ParseResult parseOptionalLSquare()=0
Parse a [ token if present.
virtual ParseResult parseAttribute(Attribute &result, Type type={})=0
Parse an arbitrary attribute of a given type and return it in result.
virtual void printType(Type type)
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
iterator_range< args_iterator > addArguments(TypeRange types, ArrayRef< Location > locs)
Add one argument to the argument list for each type specified in the list.
Definition Block.cpp:160
Operation & front()
Definition Block.h:153
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:244
BlockArgListType getArguments()
Definition Block.h:87
static BoolAttr get(MLIRContext *context, bool value)
MLIRContext * getContext() const
Definition Builders.h:56
This is a utility class for mapping one set of IR entities to another.
Definition IRMapping.h:26
This class defines the main interface for locations in MLIR and acts as a non-nullable wrapper around...
Definition Location.h:76
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
This class provides a mutable adaptor for a range of operands.
Definition ValueRange.h:118
void append(ValueRange values)
Append the given values to the range.
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
RAII guard to reset the insertion point of the builder when destroyed.
Definition Builders.h:348
This class helps build Operations.
Definition Builders.h:207
Block * createBlock(Region *parent, Region::iterator insertPt={}, TypeRange argTypes={}, ArrayRef< Location > locs={})
Add a new block with 'argTypes' arguments and set the insertion point to the end of it.
Definition Builders.cpp:430
void setInsertionPointToStart(Block *block)
Sets the insertion point to the start of the specified block.
Definition Builders.h:431
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
Location getLoc()
The source location the operation was defined or derived from.
This provides public APIs that all operations should have.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
OperandRange operand_range
Definition Operation.h:371
void setAttr(StringAttr name, Attribute value)
If an attribute exists with the specified name, change it to the new value.
Definition Operation.h:582
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:378
result_range getResults()
Definition Operation.h:415
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class represents a point being branched from in the methods of the RegionBranchOpInterface.
bool isParent() const
Returns true if branching from the parent op.
This class represents a successor of a region.
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
bool hasOneBlock()
Return true if this region has exactly one block.
Definition Region.h:68
RewritePatternSet & add(ConstructorArg &&arg, ConstructorArgs &&...args)
Add an instance of each of the pattern types 'Ts' to the pattern list with the given arguments.
virtual void replaceOp(Operation *op, ValueRange newValues)
Replace the results of the given (original) operation with the specified list of values (replacements...
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
virtual void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, ValueRange argValues={})
Inline the operations of block 'source' into block 'dest' before the given position.
void modifyOpInPlace(Operation *root, CallableT &&callable)
This method is a utility wrapper around an in-place modification of an operation.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
bool isIntOrIndexOrFloat() const
Return true if this is an integer (of any signedness), index, or float type.
Definition Types.cpp:120
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
Type getType() const
Return the type of this value.
Definition Value.h:105
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static ConstantIndexOp create(OpBuilder &builder, Location location, int64_t value)
Definition ArithOps.cpp:359
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< int32_t > content)
mlir::Value getAccVar(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation.
Definition OpenACC.cpp:4854
mlir::Value getVar(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation.
Definition OpenACC.cpp:4823
mlir::TypedValue< mlir::acc::PointerLikeType > getAccPtr(mlir::Operation *accDataClauseOp)
Used to obtain the accVar from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4842
std::optional< mlir::acc::DataClause > getDataClause(mlir::Operation *accDataEntryOp)
Used to obtain the dataClause from a data entry operation.
Definition OpenACC.cpp:4927
mlir::MutableOperandRange getMutableDataOperands(mlir::Operation *accOp)
Used to get a mutable range iterating over the data operands.
Definition OpenACC.cpp:4955
mlir::SmallVector< mlir::Value > getBounds(mlir::Operation *accDataClauseOp)
Used to obtain bounds from an acc data clause operation.
Definition OpenACC.cpp:4872
std::optional< ClauseDefaultValue > getDefaultAttr(mlir::Operation *op)
Looks for an OpenACC default attribute on the current operation op or in a parent operation which enc...
mlir::ValueRange getDataOperands(mlir::Operation *accOp)
Used to get an immutable range iterating over the data operands.
Definition OpenACC.cpp:4945
std::optional< llvm::StringRef > getVarName(mlir::Operation *accOp)
Used to obtain the name from an acc operation.
Definition OpenACC.cpp:4916
bool getImplicitFlag(mlir::Operation *accDataEntryOp)
Used to find out whether data operation is implicit.
Definition OpenACC.cpp:4937
mlir::SymbolRefAttr getRecipe(mlir::Operation *accOp)
Used to get the recipe attribute from a data clause operation.
Definition OpenACC.cpp:4964
mlir::SmallVector< mlir::Value > getAsyncOperands(mlir::Operation *accDataClauseOp)
Used to obtain async operands from an acc data clause operation.
Definition OpenACC.cpp:4887
bool isMappableType(mlir::Type type)
Used to check whether the provided type implements the MappableType interface.
Definition OpenACC.h:166
mlir::Value getVarPtrPtr(mlir::Operation *accDataClauseOp)
Used to obtain the varPtrPtr from a data clause operation.
Definition OpenACC.cpp:4862
static constexpr StringLiteral getVarNameAttrName()
Definition OpenACC.h:203
mlir::ArrayAttr getAsyncOnly(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4909
mlir::Type getVarType(mlir::Operation *accDataClauseOp)
Used to obtain the varType from a data clause operation, which records the type of the variable.
Definition OpenACC.cpp:4831
mlir::TypedValue< mlir::acc::PointerLikeType > getVarPtr(mlir::Operation *accDataClauseOp)
Used to obtain the var from a data clause operation if it implements PointerLikeType.
Definition OpenACC.cpp:4809
mlir::ArrayAttr getAsyncOperandsDeviceType(mlir::Operation *accDataClauseOp)
Returns an array of acc::DeviceTypeAttr attributes attached to an acc data clause operation,...
Definition OpenACC.cpp:4901
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:573
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:497
detail::DenseArrayAttrImpl< int32_t > DenseI32ArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
This is the representation of an operand reference.
OpRewritePattern is a wrapper around RewritePattern that allows for matching and rewriting against an...
This represents an operation in an abstracted form, suitable for use with the builder APIs.
Region * addRegion()
Create a region that should be attached to the operation.