MLIR 22.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
23#include "mlir/IR/SymbolTable.h"
25
26#include "llvm/ADT/ArrayRef.h"
27#include "llvm/ADT/PostOrderIterator.h"
28#include "llvm/ADT/STLExtras.h"
29#include "llvm/ADT/STLForwardCompat.h"
30#include "llvm/ADT/SmallString.h"
31#include "llvm/ADT/StringExtras.h"
32#include "llvm/ADT/StringRef.h"
33#include "llvm/ADT/TypeSwitch.h"
34#include "llvm/ADT/bit.h"
35#include "llvm/Support/InterleavedRange.h"
36#include <cstddef>
37#include <iterator>
38#include <optional>
39#include <variant>
40
41#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
42#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
43#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
45
46using namespace mlir;
47using namespace mlir::omp;
48
51 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
52}
53
56 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
57}
58
61 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
62}
63
64namespace {
65struct MemRefPointerLikeModel
66 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
67 MemRefType> {
68 Type getElementType(Type pointer) const {
69 return llvm::cast<MemRefType>(pointer).getElementType();
70 }
71};
72
73struct LLVMPointerPointerLikeModel
74 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
75 LLVM::LLVMPointerType> {
76 Type getElementType(Type pointer) const { return Type(); }
77};
78} // namespace
79
80/// Generate a name of a canonical loop nest of the format
81/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
82/// argument index of an operation that has multiple regions, if the operation
83/// has multiple regions.
84/// `_s<idx>` identifies the position of an operation within a region, where
85/// only operations that may potentially contain loops ("container operations"
86/// i.e. have region arguments) are counted. Again, it is omitted if there is
87/// only one such operation in a region. If there are canonical loops nested
88/// inside each other, also may also use the format `_d<num>` where <num> is the
89/// nesting depth of the loop.
90///
91/// The generated name is a best-effort to make canonical loop unique within an
92/// SSA namespace. This also means that regions with IsolatedFromAbove property
93/// do not consider any parents or siblings.
94static std::string generateLoopNestingName(StringRef prefix,
95 CanonicalLoopOp op) {
96 struct Component {
97 /// If true, this component describes a region operand of an operation (the
98 /// operand's owner) If false, this component describes an operation located
99 /// in a parent region
100 bool isRegionArgOfOp;
101 bool skip = false;
102 bool isUnique = false;
103
104 size_t idx;
105 Operation *op;
106 Region *parentRegion;
107 size_t loopDepth;
108
109 Operation *&getOwnerOp() {
110 assert(isRegionArgOfOp && "Must describe a region operand");
111 return op;
112 }
113 size_t &getArgIdx() {
114 assert(isRegionArgOfOp && "Must describe a region operand");
115 return idx;
116 }
117
118 Operation *&getContainerOp() {
119 assert(!isRegionArgOfOp && "Must describe a operation of a region");
120 return op;
121 }
122 size_t &getOpPos() {
123 assert(!isRegionArgOfOp && "Must describe a operation of a region");
124 return idx;
125 }
126 bool isLoopOp() const {
127 assert(!isRegionArgOfOp && "Must describe a operation of a region");
128 return isa<CanonicalLoopOp>(op);
129 }
130 Region *&getParentRegion() {
131 assert(!isRegionArgOfOp && "Must describe a operation of a region");
132 return parentRegion;
133 }
134 size_t &getLoopDepth() {
135 assert(!isRegionArgOfOp && "Must describe a operation of a region");
136 return loopDepth;
137 }
138
139 void skipIf(bool v = true) { skip = skip || v; }
140 };
141
142 // List of ancestors, from inner to outer.
143 // Alternates between
144 // * region argument of an operation
145 // * operation within a region
146 SmallVector<Component> components;
147
148 // Gather a list of parent regions and operations, and the position within
149 // their parent
150 Operation *o = op.getOperation();
151 while (o) {
152 // Operation within a region
153 Region *r = o->getParentRegion();
154 if (!r)
155 break;
156
157 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
158 size_t idx = 0;
159 bool found = false;
160 size_t sequentialIdx = -1;
161 bool isOnlyContainerOp = true;
162 for (Block *b : traversal) {
163 for (Operation &op : *b) {
164 if (&op == o && !found) {
165 sequentialIdx = idx;
166 found = true;
167 }
168 if (op.getNumRegions()) {
169 idx += 1;
170 if (idx > 1)
171 isOnlyContainerOp = false;
172 }
173 if (found && !isOnlyContainerOp)
174 break;
175 }
176 }
177
178 Component &containerOpInRegion = components.emplace_back();
179 containerOpInRegion.isRegionArgOfOp = false;
180 containerOpInRegion.isUnique = isOnlyContainerOp;
181 containerOpInRegion.getContainerOp() = o;
182 containerOpInRegion.getOpPos() = sequentialIdx;
183 containerOpInRegion.getParentRegion() = r;
184
185 Operation *parent = r->getParentOp();
186
187 // Region argument of an operation
188 Component &regionArgOfOperation = components.emplace_back();
189 regionArgOfOperation.isRegionArgOfOp = true;
190 regionArgOfOperation.isUnique = true;
191 regionArgOfOperation.getArgIdx() = 0;
192 regionArgOfOperation.getOwnerOp() = parent;
193
194 // The IsolatedFromAbove trait of the parent operation implies that each
195 // individual region argument has its own separate namespace, so no
196 // ambiguity.
197 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
198 break;
199
200 // Component only needed if operation has multiple region operands. Region
201 // arguments may be optional, but we currently do not consider this.
202 if (parent->getRegions().size() > 1) {
203 auto getRegionIndex = [](Operation *o, Region *r) {
204 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
205 if (&region == r)
206 return idx;
207 }
208 llvm_unreachable("Region not child of its parent operation");
209 };
210 regionArgOfOperation.isUnique = false;
211 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
212 }
213
214 // next parent
215 o = parent;
216 }
217
218 // Determine whether a region-argument component is not needed
219 for (Component &c : components)
220 c.skipIf(c.isRegionArgOfOp && c.isUnique);
221
222 // Find runs of nested loops and determine each loop's depth in the loop nest
223 size_t numSurroundingLoops = 0;
224 for (Component &c : llvm::reverse(components)) {
225 if (c.skip)
226 continue;
227
228 // non-skipped multi-argument operands interrupt the loop nest
229 if (c.isRegionArgOfOp) {
230 numSurroundingLoops = 0;
231 continue;
232 }
233
234 // Multiple loops in a region means each of them is the outermost loop of a
235 // new loop nest
236 if (!c.isUnique)
237 numSurroundingLoops = 0;
238
239 c.getLoopDepth() = numSurroundingLoops;
240
241 // Next loop is surrounded by one more loop
242 if (isa<CanonicalLoopOp>(c.getContainerOp()))
243 numSurroundingLoops += 1;
244 }
245
246 // In loop nests, skip all but the innermost loop that contains the depth
247 // number
248 bool isLoopNest = false;
249 for (Component &c : components) {
250 if (c.skip || c.isRegionArgOfOp)
251 continue;
252
253 if (!isLoopNest && c.getLoopDepth() >= 1) {
254 // Innermost loop of a loop nest of at least two loops
255 isLoopNest = true;
256 } else if (isLoopNest) {
257 // Non-innermost loop of a loop nest
258 c.skipIf(c.isUnique);
259
260 // If there is no surrounding loop left, this must have been the outermost
261 // loop; leave loop-nest mode for the next iteration
262 if (c.getLoopDepth() == 0)
263 isLoopNest = false;
264 }
265 }
266
267 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
268 // we mark them as skipped only after computing loop nests)
269 for (Component &c : components)
270 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
271 !isa<CanonicalLoopOp>(c.getContainerOp()));
272
273 // Components can be skipped if they are already disambiguated by their parent
274 // (or does not have a parent)
275 bool newRegion = true;
276 for (Component &c : llvm::reverse(components)) {
277 c.skipIf(newRegion && c.isUnique);
278
279 // non-skipped components disambiguate unique children
280 if (!c.skip)
281 newRegion = true;
282
283 // ...except canonical loops that need a suffix for each nest
284 if (!c.isRegionArgOfOp && c.getContainerOp())
285 newRegion = false;
286 }
287
288 // Compile the nesting name string
289 SmallString<64> Name{prefix};
290 llvm::raw_svector_ostream NameOS(Name);
291 for (auto &c : llvm::reverse(components)) {
292 if (c.skip)
293 continue;
294
295 if (c.isRegionArgOfOp)
296 NameOS << "_r" << c.getArgIdx();
297 else if (c.getLoopDepth() >= 1)
298 NameOS << "_d" << c.getLoopDepth();
299 else
300 NameOS << "_s" << c.getOpPos();
301 }
302
303 return NameOS.str().str();
304}
305
306void OpenMPDialect::initialize() {
307 addOperations<
308#define GET_OP_LIST
309#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
310 >();
311 addAttributes<
312#define GET_ATTRDEF_LIST
313#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
314 >();
315 addTypes<
316#define GET_TYPEDEF_LIST
317#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
318 >();
319
320 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
321
322 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
323 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
324 *getContext());
325
326 // Attach default offload module interface to module op to access
327 // offload functionality through
328 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
329 *getContext());
330
331 // Attach default declare target interfaces to operations which can be marked
332 // as declare target (Global Operations and Functions/Subroutines in dialects
333 // that Fortran (or other languages that lower to MLIR) translates too
334 mlir::LLVM::GlobalOp::attachInterface<
336 *getContext());
337 mlir::LLVM::LLVMFuncOp::attachInterface<
339 *getContext());
340 mlir::func::FuncOp::attachInterface<
342}
343
344//===----------------------------------------------------------------------===//
345// Parser and printer for Allocate Clause
346//===----------------------------------------------------------------------===//
347
348/// Parse an allocate clause with allocators and a list of operands with types.
349///
350/// allocate-operand-list :: = allocate-operand |
351/// allocator-operand `,` allocate-operand-list
352/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
353/// ssa-id-and-type ::= ssa-id `:` type
354static ParseResult parseAllocateAndAllocator(
355 OpAsmParser &parser,
357 SmallVectorImpl<Type> &allocateTypes,
359 SmallVectorImpl<Type> &allocatorTypes) {
360
361 return parser.parseCommaSeparatedList([&]() {
363 Type type;
364 if (parser.parseOperand(operand) || parser.parseColonType(type))
365 return failure();
366 allocatorVars.push_back(operand);
367 allocatorTypes.push_back(type);
368 if (parser.parseArrow())
369 return failure();
370 if (parser.parseOperand(operand) || parser.parseColonType(type))
371 return failure();
372
373 allocateVars.push_back(operand);
374 allocateTypes.push_back(type);
375 return success();
376 });
377}
378
379/// Print allocate clause
381 OperandRange allocateVars,
382 TypeRange allocateTypes,
383 OperandRange allocatorVars,
384 TypeRange allocatorTypes) {
385 for (unsigned i = 0; i < allocateVars.size(); ++i) {
386 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
387 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
388 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
389 }
390}
391
392//===----------------------------------------------------------------------===//
393// Parser and printer for a clause attribute (StringEnumAttr)
394//===----------------------------------------------------------------------===//
395
396template <typename ClauseAttr>
397static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
398 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
399 StringRef enumStr;
400 SMLoc loc = parser.getCurrentLocation();
401 if (parser.parseKeyword(&enumStr))
402 return failure();
403 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
404 attr = ClauseAttr::get(parser.getContext(), *enumValue);
405 return success();
406 }
407 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
408}
409
410template <typename ClauseAttr>
411static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
412 p << stringifyEnum(attr.getValue());
413}
414
415//===----------------------------------------------------------------------===//
416// Parser and printer for Linear Clause
417//===----------------------------------------------------------------------===//
418
419/// linear ::= `linear` `(` linear-list `)`
420/// linear-list := linear-val | linear-val linear-list
421/// linear-val := ssa-id-and-type `=` ssa-id-and-type
422static ParseResult parseLinearClause(
423 OpAsmParser &parser,
425 SmallVectorImpl<Type> &linearTypes,
427 return parser.parseCommaSeparatedList([&]() {
429 Type type;
431 if (parser.parseOperand(var) || parser.parseEqual() ||
432 parser.parseOperand(stepVar) || parser.parseColonType(type))
433 return failure();
434
435 linearVars.push_back(var);
436 linearTypes.push_back(type);
437 linearStepVars.push_back(stepVar);
438 return success();
439 });
440}
441
442/// Print Linear Clause
444 ValueRange linearVars, TypeRange linearTypes,
445 ValueRange linearStepVars) {
446 size_t linearVarsSize = linearVars.size();
447 for (unsigned i = 0; i < linearVarsSize; ++i) {
448 std::string separator = i == linearVarsSize - 1 ? "" : ", ";
449 p << linearVars[i];
450 if (linearStepVars.size() > i)
451 p << " = " << linearStepVars[i];
452 p << " : " << linearVars[i].getType() << separator;
453 }
454}
455
456//===----------------------------------------------------------------------===//
457// Verifier for Nontemporal Clause
458//===----------------------------------------------------------------------===//
459
460static LogicalResult verifyNontemporalClause(Operation *op,
461 OperandRange nontemporalVars) {
462
463 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
464 DenseSet<Value> nontemporalItems;
465 for (const auto &it : nontemporalVars)
466 if (!nontemporalItems.insert(it).second)
467 return op->emitOpError() << "nontemporal variable used more than once";
468
469 return success();
470}
471
472//===----------------------------------------------------------------------===//
473// Parser, verifier and printer for Aligned Clause
474//===----------------------------------------------------------------------===//
475static LogicalResult verifyAlignedClause(Operation *op,
476 std::optional<ArrayAttr> alignments,
477 OperandRange alignedVars) {
478 // Check if number of alignment values equals to number of aligned variables
479 if (!alignedVars.empty()) {
480 if (!alignments || alignments->size() != alignedVars.size())
481 return op->emitOpError()
482 << "expected as many alignment values as aligned variables";
483 } else {
484 if (alignments)
485 return op->emitOpError() << "unexpected alignment values attribute";
486 return success();
487 }
488
489 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
490 DenseSet<Value> alignedItems;
491 for (auto it : alignedVars)
492 if (!alignedItems.insert(it).second)
493 return op->emitOpError() << "aligned variable used more than once";
494
495 if (!alignments)
496 return success();
497
498 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
499 for (unsigned i = 0; i < (*alignments).size(); ++i) {
500 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
501 if (intAttr.getValue().sle(0))
502 return op->emitOpError() << "alignment should be greater than 0";
503 } else {
504 return op->emitOpError() << "expected integer alignment";
505 }
506 }
507
508 return success();
509}
510
511/// aligned ::= `aligned` `(` aligned-list `)`
512/// aligned-list := aligned-val | aligned-val aligned-list
513/// aligned-val := ssa-id-and-type `->` alignment
514static ParseResult
517 SmallVectorImpl<Type> &alignedTypes,
518 ArrayAttr &alignmentsAttr) {
519 SmallVector<Attribute> alignmentVec;
520 if (failed(parser.parseCommaSeparatedList([&]() {
521 if (parser.parseOperand(alignedVars.emplace_back()) ||
522 parser.parseColonType(alignedTypes.emplace_back()) ||
523 parser.parseArrow() ||
524 parser.parseAttribute(alignmentVec.emplace_back())) {
525 return failure();
526 }
527 return success();
528 })))
529 return failure();
530 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
531 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
532 return success();
533}
534
535/// Print Aligned Clause
537 ValueRange alignedVars, TypeRange alignedTypes,
538 std::optional<ArrayAttr> alignments) {
539 for (unsigned i = 0; i < alignedVars.size(); ++i) {
540 if (i != 0)
541 p << ", ";
542 p << alignedVars[i] << " : " << alignedVars[i].getType();
543 p << " -> " << (*alignments)[i];
544 }
545}
546
547//===----------------------------------------------------------------------===//
548// Parser, printer and verifier for Schedule Clause
549//===----------------------------------------------------------------------===//
550
551static ParseResult
553 SmallVectorImpl<SmallString<12>> &modifiers) {
554 if (modifiers.size() > 2)
555 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
556 for (const auto &mod : modifiers) {
557 // Translate the string. If it has no value, then it was not a valid
558 // modifier!
559 auto symbol = symbolizeScheduleModifier(mod);
560 if (!symbol)
561 return parser.emitError(parser.getNameLoc())
562 << " unknown modifier type: " << mod;
563 }
564
565 // If we have one modifier that is "simd", then stick a "none" modiifer in
566 // index 0.
567 if (modifiers.size() == 1) {
568 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
569 modifiers.push_back(modifiers[0]);
570 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
571 }
572 } else if (modifiers.size() == 2) {
573 // If there are two modifier:
574 // First modifier should not be simd, second one should be simd
575 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
576 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
577 return parser.emitError(parser.getNameLoc())
578 << " incorrect modifier order";
579 }
580 return success();
581}
582
583/// schedule ::= `schedule` `(` sched-list `)`
584/// sched-list ::= sched-val | sched-val sched-list |
585/// sched-val `,` sched-modifier
586/// sched-val ::= sched-with-chunk | sched-wo-chunk
587/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
588/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
589/// sched-wo-chunk ::= `auto` | `runtime`
590/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
591/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
592static ParseResult
593parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
594 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
595 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
596 Type &chunkType) {
597 StringRef keyword;
598 if (parser.parseKeyword(&keyword))
599 return failure();
600 std::optional<mlir::omp::ClauseScheduleKind> schedule =
601 symbolizeClauseScheduleKind(keyword);
602 if (!schedule)
603 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
604
605 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
606 switch (*schedule) {
607 case ClauseScheduleKind::Static:
608 case ClauseScheduleKind::Dynamic:
609 case ClauseScheduleKind::Guided:
610 if (succeeded(parser.parseOptionalEqual())) {
611 chunkSize = OpAsmParser::UnresolvedOperand{};
612 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
613 return failure();
614 } else {
615 chunkSize = std::nullopt;
616 }
617 break;
618 case ClauseScheduleKind::Auto:
619 case ClauseScheduleKind::Runtime:
620 chunkSize = std::nullopt;
621 }
622
623 // If there is a comma, we have one or more modifiers..
625 while (succeeded(parser.parseOptionalComma())) {
626 StringRef mod;
627 if (parser.parseKeyword(&mod))
628 return failure();
629 modifiers.push_back(mod);
630 }
631
632 if (verifyScheduleModifiers(parser, modifiers))
633 return failure();
634
635 if (!modifiers.empty()) {
636 SMLoc loc = parser.getCurrentLocation();
637 if (std::optional<ScheduleModifier> mod =
638 symbolizeScheduleModifier(modifiers[0])) {
639 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
640 } else {
641 return parser.emitError(loc, "invalid schedule modifier");
642 }
643 // Only SIMD attribute is allowed here!
644 if (modifiers.size() > 1) {
645 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
646 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
647 }
648 }
649
650 return success();
651}
652
653/// Print schedule clause
655 ClauseScheduleKindAttr scheduleKind,
656 ScheduleModifierAttr scheduleMod,
657 UnitAttr scheduleSimd, Value scheduleChunk,
658 Type scheduleChunkType) {
659 p << stringifyClauseScheduleKind(scheduleKind.getValue());
660 if (scheduleChunk)
661 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
662 if (scheduleMod)
663 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
664 if (scheduleSimd)
665 p << ", simd";
666}
667
668//===----------------------------------------------------------------------===//
669// Parser and printer for Order Clause
670//===----------------------------------------------------------------------===//
671
672// order ::= `order` `(` [order-modifier ':'] concurrent `)`
673// order-modifier ::= reproducible | unconstrained
674static ParseResult parseOrderClause(OpAsmParser &parser,
675 ClauseOrderKindAttr &order,
676 OrderModifierAttr &orderMod) {
677 StringRef enumStr;
678 SMLoc loc = parser.getCurrentLocation();
679 if (parser.parseKeyword(&enumStr))
680 return failure();
681 if (std::optional<OrderModifier> enumValue =
682 symbolizeOrderModifier(enumStr)) {
683 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
684 if (parser.parseOptionalColon())
685 return failure();
686 loc = parser.getCurrentLocation();
687 if (parser.parseKeyword(&enumStr))
688 return failure();
689 }
690 if (std::optional<ClauseOrderKind> enumValue =
691 symbolizeClauseOrderKind(enumStr)) {
692 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
693 return success();
694 }
695 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
696}
697
699 ClauseOrderKindAttr order,
700 OrderModifierAttr orderMod) {
701 if (orderMod)
702 p << stringifyOrderModifier(orderMod.getValue()) << ":";
703 if (order)
704 p << stringifyClauseOrderKind(order.getValue());
705}
706
707template <typename ClauseTypeAttr, typename ClauseType>
708static ParseResult
709parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
710 std::optional<OpAsmParser::UnresolvedOperand> &operand,
711 Type &operandType,
712 std::optional<ClauseType> (*symbolizeClause)(StringRef),
713 StringRef clauseName) {
714 StringRef enumStr;
715 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
716 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
717 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
718 if (parser.parseComma())
719 return failure();
720 } else {
721 return parser.emitError(parser.getCurrentLocation())
722 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
723 ;
724 }
725 }
726
728 if (succeeded(parser.parseOperand(var))) {
729 operand = var;
730 } else {
731 return parser.emitError(parser.getCurrentLocation())
732 << "expected " << clauseName << " operand";
733 }
734
735 if (operand.has_value()) {
736 if (parser.parseColonType(operandType))
737 return failure();
738 }
739
740 return success();
741}
742
743template <typename ClauseTypeAttr, typename ClauseType>
744static void
746 ClauseTypeAttr prescriptiveness, Value operand,
747 mlir::Type operandType,
748 StringRef (*stringifyClauseType)(ClauseType)) {
749
750 if (prescriptiveness)
751 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
752
753 if (operand)
754 p << operand << ": " << operandType;
755}
756
757//===----------------------------------------------------------------------===//
758// Parser and printer for grainsize Clause
759//===----------------------------------------------------------------------===//
760
761// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
762static ParseResult
763parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
764 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
765 Type &grainsizeType) {
767 parser, grainsizeMod, grainsize, grainsizeType,
768 &symbolizeClauseGrainsizeType, "grainsize");
769}
770
772 ClauseGrainsizeTypeAttr grainsizeMod,
773 Value grainsize, mlir::Type grainsizeType) {
775 p, op, grainsizeMod, grainsize, grainsizeType,
776 &stringifyClauseGrainsizeType);
777}
778
779//===----------------------------------------------------------------------===//
780// Parser and printer for num_tasks Clause
781//===----------------------------------------------------------------------===//
782
783// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
784static ParseResult
785parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
786 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
787 Type &numTasksType) {
789 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
790 "num_tasks");
791}
792
794 ClauseNumTasksTypeAttr numTasksMod,
795 Value numTasks, mlir::Type numTasksType) {
797 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
798}
799
800//===----------------------------------------------------------------------===//
801// Parsers for operations including clauses that define entry block arguments.
802//===----------------------------------------------------------------------===//
803
804namespace {
805struct MapParseArgs {
806 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
807 SmallVectorImpl<Type> &types;
808 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
809 SmallVectorImpl<Type> &types)
810 : vars(vars), types(types) {}
811};
812struct PrivateParseArgs {
813 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
814 llvm::SmallVectorImpl<Type> &types;
815 ArrayAttr &syms;
816 UnitAttr &needsBarrier;
817 DenseI64ArrayAttr *mapIndices;
818 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
819 SmallVectorImpl<Type> &types, ArrayAttr &syms,
820 UnitAttr &needsBarrier,
821 DenseI64ArrayAttr *mapIndices = nullptr)
822 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
823 mapIndices(mapIndices) {}
824};
825
826struct ReductionParseArgs {
827 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
828 SmallVectorImpl<Type> &types;
829 DenseBoolArrayAttr &byref;
830 ArrayAttr &syms;
831 ReductionModifierAttr *modifier;
832 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
833 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
834 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
835 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
836};
837
838struct AllRegionParseArgs {
839 std::optional<MapParseArgs> hasDeviceAddrArgs;
840 std::optional<MapParseArgs> hostEvalArgs;
841 std::optional<ReductionParseArgs> inReductionArgs;
842 std::optional<MapParseArgs> mapArgs;
843 std::optional<PrivateParseArgs> privateArgs;
844 std::optional<ReductionParseArgs> reductionArgs;
845 std::optional<ReductionParseArgs> taskReductionArgs;
846 std::optional<MapParseArgs> useDeviceAddrArgs;
847 std::optional<MapParseArgs> useDevicePtrArgs;
848};
849} // namespace
850
851static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
852 return "private_barrier";
853}
854
855static ParseResult parseClauseWithRegionArgs(
856 OpAsmParser &parser,
860 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
861 DenseBoolArrayAttr *byref = nullptr,
862 ReductionModifierAttr *modifier = nullptr,
863 UnitAttr *needsBarrier = nullptr) {
865 SmallVector<int64_t> mapIndicesVec;
866 SmallVector<bool> isByRefVec;
867 unsigned regionArgOffset = regionPrivateArgs.size();
868
869 if (parser.parseLParen())
870 return failure();
871
872 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
873 StringRef enumStr;
874 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
875 parser.parseComma())
876 return failure();
877 std::optional<ReductionModifier> enumValue =
878 symbolizeReductionModifier(enumStr);
879 if (!enumValue.has_value())
880 return failure();
881 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
882 if (!*modifier)
883 return failure();
884 }
885
886 if (parser.parseCommaSeparatedList([&]() {
887 if (byref)
888 isByRefVec.push_back(
889 parser.parseOptionalKeyword("byref").succeeded());
890
891 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
892 return failure();
893
894 if (parser.parseOperand(operands.emplace_back()) ||
895 parser.parseArrow() ||
896 parser.parseArgument(regionPrivateArgs.emplace_back()))
897 return failure();
898
899 if (mapIndices) {
900 if (parser.parseOptionalLSquare().succeeded()) {
901 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
902 parser.parseInteger(mapIndicesVec.emplace_back()) ||
903 parser.parseRSquare())
904 return failure();
905 } else {
906 mapIndicesVec.push_back(-1);
907 }
908 }
909
910 return success();
911 }))
912 return failure();
913
914 if (parser.parseColon())
915 return failure();
916
917 if (parser.parseCommaSeparatedList([&]() {
918 if (parser.parseType(types.emplace_back()))
919 return failure();
920
921 return success();
922 }))
923 return failure();
924
925 if (operands.size() != types.size())
926 return failure();
927
928 if (parser.parseRParen())
929 return failure();
930
931 if (needsBarrier) {
933 .succeeded())
934 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
935 }
936
937 auto *argsBegin = regionPrivateArgs.begin();
938 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
939 argsBegin + regionArgOffset + types.size());
940 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
941 prv.type = type;
942 }
943
944 if (symbols) {
945 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
946 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
947 }
948
949 if (!mapIndicesVec.empty())
950 *mapIndices =
951 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
952
953 if (byref)
954 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
955
956 return success();
957}
958
959static ParseResult parseBlockArgClause(
960 OpAsmParser &parser,
962 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
963 if (succeeded(parser.parseOptionalKeyword(keyword))) {
964 if (!mapArgs)
965 return failure();
966
967 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
968 entryBlockArgs)))
969 return failure();
970 }
971 return success();
972}
973
974static ParseResult parseBlockArgClause(
975 OpAsmParser &parser,
977 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
978 if (succeeded(parser.parseOptionalKeyword(keyword))) {
979 if (!privateArgs)
980 return failure();
981
982 if (failed(parseClauseWithRegionArgs(
983 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
984 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
985 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
986 return failure();
987 }
988 return success();
989}
990
991static ParseResult parseBlockArgClause(
992 OpAsmParser &parser,
994 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
995 if (succeeded(parser.parseOptionalKeyword(keyword))) {
996 if (!reductionArgs)
997 return failure();
998 if (failed(parseClauseWithRegionArgs(
999 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1000 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1001 reductionArgs->modifier)))
1002 return failure();
1003 }
1004 return success();
1005}
1006
1007static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1008 AllRegionParseArgs args) {
1010
1011 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1012 args.hasDeviceAddrArgs)))
1013 return parser.emitError(parser.getCurrentLocation())
1014 << "invalid `has_device_addr` format";
1015
1016 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1017 args.hostEvalArgs)))
1018 return parser.emitError(parser.getCurrentLocation())
1019 << "invalid `host_eval` format";
1020
1021 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1022 args.inReductionArgs)))
1023 return parser.emitError(parser.getCurrentLocation())
1024 << "invalid `in_reduction` format";
1025
1026 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1027 args.mapArgs)))
1028 return parser.emitError(parser.getCurrentLocation())
1029 << "invalid `map_entries` format";
1030
1031 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1032 args.privateArgs)))
1033 return parser.emitError(parser.getCurrentLocation())
1034 << "invalid `private` format";
1035
1036 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1037 args.reductionArgs)))
1038 return parser.emitError(parser.getCurrentLocation())
1039 << "invalid `reduction` format";
1040
1041 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1042 args.taskReductionArgs)))
1043 return parser.emitError(parser.getCurrentLocation())
1044 << "invalid `task_reduction` format";
1045
1046 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1047 args.useDeviceAddrArgs)))
1048 return parser.emitError(parser.getCurrentLocation())
1049 << "invalid `use_device_addr` format";
1050
1051 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1052 args.useDevicePtrArgs)))
1053 return parser.emitError(parser.getCurrentLocation())
1054 << "invalid `use_device_addr` format";
1055
1056 return parser.parseRegion(region, entryBlockArgs);
1057}
1058
1059// These parseXyz functions correspond to the custom<Xyz> definitions
1060// in the .td file(s).
1061static ParseResult parseTargetOpRegion(
1062 OpAsmParser &parser, Region &region,
1064 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1066 SmallVectorImpl<Type> &hostEvalTypes,
1068 SmallVectorImpl<Type> &inReductionTypes,
1069 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1071 SmallVectorImpl<Type> &mapTypes,
1073 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1074 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1075 AllRegionParseArgs args;
1076 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1077 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1078 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1079 inReductionByref, inReductionSyms);
1080 args.mapArgs.emplace(mapVars, mapTypes);
1081 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1082 privateNeedsBarrier, &privateMaps);
1083 return parseBlockArgRegion(parser, region, args);
1084}
1085
1087 OpAsmParser &parser, Region &region,
1089 SmallVectorImpl<Type> &inReductionTypes,
1090 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1092 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1093 UnitAttr &privateNeedsBarrier) {
1094 AllRegionParseArgs args;
1095 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1096 inReductionByref, inReductionSyms);
1097 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1098 privateNeedsBarrier);
1099 return parseBlockArgRegion(parser, region, args);
1100}
1101
1103 OpAsmParser &parser, Region &region,
1105 SmallVectorImpl<Type> &inReductionTypes,
1106 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1108 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1109 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1111 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1112 ArrayAttr &reductionSyms) {
1113 AllRegionParseArgs args;
1114 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1115 inReductionByref, inReductionSyms);
1116 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1117 privateNeedsBarrier);
1118 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1119 reductionSyms, &reductionMod);
1120 return parseBlockArgRegion(parser, region, args);
1121}
1122
1123static ParseResult parsePrivateRegion(
1124 OpAsmParser &parser, Region &region,
1126 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1127 UnitAttr &privateNeedsBarrier) {
1128 AllRegionParseArgs args;
1129 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1130 privateNeedsBarrier);
1131 return parseBlockArgRegion(parser, region, args);
1132}
1133
1135 OpAsmParser &parser, Region &region,
1137 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1138 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1140 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1141 ArrayAttr &reductionSyms) {
1142 AllRegionParseArgs args;
1143 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1144 privateNeedsBarrier);
1145 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1146 reductionSyms, &reductionMod);
1147 return parseBlockArgRegion(parser, region, args);
1148}
1149
1150static ParseResult parseTaskReductionRegion(
1151 OpAsmParser &parser, Region &region,
1153 SmallVectorImpl<Type> &taskReductionTypes,
1154 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1155 AllRegionParseArgs args;
1156 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1157 taskReductionByref, taskReductionSyms);
1158 return parseBlockArgRegion(parser, region, args);
1159}
1160
1162 OpAsmParser &parser, Region &region,
1164 SmallVectorImpl<Type> &useDeviceAddrTypes,
1166 SmallVectorImpl<Type> &useDevicePtrTypes) {
1167 AllRegionParseArgs args;
1168 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1169 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1170 return parseBlockArgRegion(parser, region, args);
1171}
1172
1173//===----------------------------------------------------------------------===//
1174// Printers for operations including clauses that define entry block arguments.
1175//===----------------------------------------------------------------------===//
1176
1177namespace {
1178struct MapPrintArgs {
1179 ValueRange vars;
1180 TypeRange types;
1181 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1182};
1183struct PrivatePrintArgs {
1184 ValueRange vars;
1185 TypeRange types;
1186 ArrayAttr syms;
1187 UnitAttr needsBarrier;
1188 DenseI64ArrayAttr mapIndices;
1189 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1190 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1191 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1192 mapIndices(mapIndices) {}
1193};
1194struct ReductionPrintArgs {
1195 ValueRange vars;
1196 TypeRange types;
1197 DenseBoolArrayAttr byref;
1198 ArrayAttr syms;
1199 ReductionModifierAttr modifier;
1200 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1201 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1202 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1203};
1204struct AllRegionPrintArgs {
1205 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1206 std::optional<MapPrintArgs> hostEvalArgs;
1207 std::optional<ReductionPrintArgs> inReductionArgs;
1208 std::optional<MapPrintArgs> mapArgs;
1209 std::optional<PrivatePrintArgs> privateArgs;
1210 std::optional<ReductionPrintArgs> reductionArgs;
1211 std::optional<ReductionPrintArgs> taskReductionArgs;
1212 std::optional<MapPrintArgs> useDeviceAddrArgs;
1213 std::optional<MapPrintArgs> useDevicePtrArgs;
1214};
1215} // namespace
1216
1218 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1219 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1220 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1221 DenseBoolArrayAttr byref = nullptr,
1222 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1223 if (argsSubrange.empty())
1224 return;
1225
1226 p << clauseName << "(";
1227
1228 if (modifier)
1229 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1230
1231 if (!symbols) {
1232 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1233 symbols = ArrayAttr::get(ctx, values);
1234 }
1235
1236 if (!mapIndices) {
1237 llvm::SmallVector<int64_t> values(operands.size(), -1);
1238 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1239 }
1240
1241 if (!byref) {
1242 mlir::SmallVector<bool> values(operands.size(), false);
1243 byref = DenseBoolArrayAttr::get(ctx, values);
1244 }
1245
1246 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1247 mapIndices.asArrayRef(),
1248 byref.asArrayRef()),
1249 p, [&p](auto t) {
1250 auto [op, arg, sym, map, isByRef] = t;
1251 if (isByRef)
1252 p << "byref ";
1253 if (sym)
1254 p << sym << " ";
1255
1256 p << op << " -> " << arg;
1257
1258 if (map != -1)
1259 p << " [map_idx=" << map << "]";
1260 });
1261 p << " : ";
1262 llvm::interleaveComma(types, p);
1263 p << ") ";
1264
1265 if (needsBarrier)
1266 p << getPrivateNeedsBarrierSpelling() << " ";
1267}
1268
1270 StringRef clauseName, ValueRange argsSubrange,
1271 std::optional<MapPrintArgs> mapArgs) {
1272 if (mapArgs)
1273 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1274 mapArgs->types);
1275}
1276
1278 StringRef clauseName, ValueRange argsSubrange,
1279 std::optional<PrivatePrintArgs> privateArgs) {
1280 if (privateArgs)
1282 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1283 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1284 /*modifier=*/nullptr, privateArgs->needsBarrier);
1285}
1286
1287static void
1288printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1289 ValueRange argsSubrange,
1290 std::optional<ReductionPrintArgs> reductionArgs) {
1291 if (reductionArgs)
1292 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1293 reductionArgs->vars, reductionArgs->types,
1294 reductionArgs->syms, /*mapIndices=*/nullptr,
1295 reductionArgs->byref, reductionArgs->modifier);
1296}
1297
1299 const AllRegionPrintArgs &args) {
1300 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1301 MLIRContext *ctx = op->getContext();
1302
1303 printBlockArgClause(p, ctx, "has_device_addr",
1304 iface.getHasDeviceAddrBlockArgs(),
1305 args.hasDeviceAddrArgs);
1306 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1307 args.hostEvalArgs);
1308 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1309 args.inReductionArgs);
1310 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1311 args.mapArgs);
1312 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1313 args.privateArgs);
1314 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1315 args.reductionArgs);
1316 printBlockArgClause(p, ctx, "task_reduction",
1317 iface.getTaskReductionBlockArgs(),
1318 args.taskReductionArgs);
1319 printBlockArgClause(p, ctx, "use_device_addr",
1320 iface.getUseDeviceAddrBlockArgs(),
1321 args.useDeviceAddrArgs);
1322 printBlockArgClause(p, ctx, "use_device_ptr",
1323 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1324
1325 p.printRegion(region, /*printEntryBlockArgs=*/false);
1326}
1327
// These printXyz functions correspond to the custom<Xyz> definitions
// in the .td file(s).
1331 OpAsmPrinter &p, Operation *op, Region &region,
1332 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1333 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1334 ValueRange inReductionVars, TypeRange inReductionTypes,
1335 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1336 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1337 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1338 DenseI64ArrayAttr privateMaps) {
1339 AllRegionPrintArgs args;
1340 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1341 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1342 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1343 inReductionByref, inReductionSyms);
1344 args.mapArgs.emplace(mapVars, mapTypes);
1345 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1346 privateNeedsBarrier, privateMaps);
1347 printBlockArgRegion(p, op, region, args);
1348}
1349
1351 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1352 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1353 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1354 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1355 AllRegionPrintArgs args;
1356 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1357 inReductionByref, inReductionSyms);
1358 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1359 privateNeedsBarrier,
1360 /*mapIndices=*/nullptr);
1361 printBlockArgRegion(p, op, region, args);
1362}
1363
1365 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1366 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1367 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1368 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1369 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1370 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1371 ArrayAttr reductionSyms) {
1372 AllRegionPrintArgs args;
1373 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1374 inReductionByref, inReductionSyms);
1375 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1376 privateNeedsBarrier,
1377 /*mapIndices=*/nullptr);
1378 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1379 reductionSyms, reductionMod);
1380 printBlockArgRegion(p, op, region, args);
1381}
1382
1384 ValueRange privateVars, TypeRange privateTypes,
1385 ArrayAttr privateSyms,
1386 UnitAttr privateNeedsBarrier) {
1387 AllRegionPrintArgs args;
1388 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1389 privateNeedsBarrier,
1390 /*mapIndices=*/nullptr);
1391 printBlockArgRegion(p, op, region, args);
1392}
1393
1395 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1396 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1397 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1398 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1399 ArrayAttr reductionSyms) {
1400 AllRegionPrintArgs args;
1401 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1402 privateNeedsBarrier,
1403 /*mapIndices=*/nullptr);
1404 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1405 reductionSyms, reductionMod);
1406 printBlockArgRegion(p, op, region, args);
1407}
1408
1410 Region &region,
1411 ValueRange taskReductionVars,
1412 TypeRange taskReductionTypes,
1413 DenseBoolArrayAttr taskReductionByref,
1414 ArrayAttr taskReductionSyms) {
1415 AllRegionPrintArgs args;
1416 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1417 taskReductionByref, taskReductionSyms);
1418 printBlockArgRegion(p, op, region, args);
1419}
1420
1422 Region &region,
1423 ValueRange useDeviceAddrVars,
1424 TypeRange useDeviceAddrTypes,
1425 ValueRange useDevicePtrVars,
1426 TypeRange useDevicePtrTypes) {
1427 AllRegionPrintArgs args;
1428 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1429 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1430 printBlockArgRegion(p, op, region, args);
1431}
1432
1433/// Verifies Reduction Clause
1434static LogicalResult
1435verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1436 OperandRange reductionVars,
1437 std::optional<ArrayRef<bool>> reductionByref) {
1438 if (!reductionVars.empty()) {
1439 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1440 return op->emitOpError()
1441 << "expected as many reduction symbol references "
1442 "as reduction variables";
1443 if (reductionByref && reductionByref->size() != reductionVars.size())
1444 return op->emitError() << "expected as many reduction variable by "
1445 "reference attributes as reduction variables";
1446 } else {
1447 if (reductionSyms)
1448 return op->emitOpError() << "unexpected reduction symbol references";
1449 return success();
1450 }
1451
1452 // TODO: The followings should be done in
1453 // SymbolUserOpInterface::verifySymbolUses.
1454 DenseSet<Value> accumulators;
1455 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1456 Value accum = std::get<0>(args);
1457
1458 if (!accumulators.insert(accum).second)
1459 return op->emitOpError() << "accumulator variable used more than once";
1460
1461 Type varType = accum.getType();
1462 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1463 auto decl =
1465 if (!decl)
1466 return op->emitOpError() << "expected symbol reference " << symbolRef
1467 << " to point to a reduction declaration";
1468
1469 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1470 return op->emitOpError()
1471 << "expected accumulator (" << varType
1472 << ") to be the same type as reduction declaration ("
1473 << decl.getAccumulatorType() << ")";
1474 }
1475
1476 return success();
1477}
1478
1479//===----------------------------------------------------------------------===//
1480// Parser, printer and verifier for Copyprivate
1481//===----------------------------------------------------------------------===//
1482
1483/// copyprivate-entry-list ::= copyprivate-entry
1484/// | copyprivate-entry-list `,` copyprivate-entry
1485/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1486static ParseResult parseCopyprivate(
1487 OpAsmParser &parser,
1489 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1491 if (failed(parser.parseCommaSeparatedList([&]() {
1492 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1493 parser.parseArrow() ||
1494 parser.parseAttribute(symsVec.emplace_back()) ||
1495 parser.parseColonType(copyprivateTypes.emplace_back()))
1496 return failure();
1497 return success();
1498 })))
1499 return failure();
1500 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1501 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1502 return success();
1503}
1504
1505/// Print Copyprivate clause
1507 OperandRange copyprivateVars,
1508 TypeRange copyprivateTypes,
1509 std::optional<ArrayAttr> copyprivateSyms) {
1510 if (!copyprivateSyms.has_value())
1511 return;
1512 llvm::interleaveComma(
1513 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1514 [&](const auto &args) {
1515 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1516 << std::get<2>(args);
1517 });
1518}
1519
1520/// Verifies CopyPrivate Clause
1521static LogicalResult
1523 std::optional<ArrayAttr> copyprivateSyms) {
1524 size_t copyprivateSymsSize =
1525 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1526 if (copyprivateSymsSize != copyprivateVars.size())
1527 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1528 << copyprivateVars.size()
1529 << ") and functions (= " << copyprivateSymsSize
1530 << "), both must be equal";
1531 if (!copyprivateSyms.has_value())
1532 return success();
1533
1534 for (auto copyprivateVarAndSym :
1535 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1536 auto symbolRef =
1537 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1538 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1539 funcOp;
1540 if (mlir::func::FuncOp mlirFuncOp =
1542 symbolRef))
1543 funcOp = mlirFuncOp;
1544 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1546 op, symbolRef))
1547 funcOp = llvmFuncOp;
1548
1549 auto getNumArguments = [&] {
1550 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1551 };
1552
1553 auto getArgumentType = [&](unsigned i) {
1554 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1555 *funcOp);
1556 };
1557
1558 if (!funcOp)
1559 return op->emitOpError() << "expected symbol reference " << symbolRef
1560 << " to point to a copy function";
1561
1562 if (getNumArguments() != 2)
1563 return op->emitOpError()
1564 << "expected copy function " << symbolRef << " to have 2 operands";
1565
1566 Type argTy = getArgumentType(0);
1567 if (argTy != getArgumentType(1))
1568 return op->emitOpError() << "expected copy function " << symbolRef
1569 << " arguments to have the same type";
1570
1571 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1572 if (argTy != varType)
1573 return op->emitOpError()
1574 << "expected copy function arguments' type (" << argTy
1575 << ") to be the same as copyprivate variable's type (" << varType
1576 << ")";
1577 }
1578
1579 return success();
1580}
1581
1582//===----------------------------------------------------------------------===//
1583// Parser, printer and verifier for DependVarList
1584//===----------------------------------------------------------------------===//
1585
1586/// depend-entry-list ::= depend-entry
1587/// | depend-entry-list `,` depend-entry
1588/// depend-entry ::= depend-kind `->` ssa-id `:` type
1589static ParseResult
1592 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds) {
1594 if (failed(parser.parseCommaSeparatedList([&]() {
1595 StringRef keyword;
1596 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1597 parser.parseOperand(dependVars.emplace_back()) ||
1598 parser.parseColonType(dependTypes.emplace_back()))
1599 return failure();
1600 if (std::optional<ClauseTaskDepend> keywordDepend =
1601 (symbolizeClauseTaskDepend(keyword)))
1602 kindsVec.emplace_back(
1603 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend));
1604 else
1605 return failure();
1606 return success();
1607 })))
1608 return failure();
1609 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1610 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1611 return success();
1612}
1613
1614/// Print Depend clause
1616 OperandRange dependVars, TypeRange dependTypes,
1617 std::optional<ArrayAttr> dependKinds) {
1618
1619 for (unsigned i = 0, e = dependKinds->size(); i < e; ++i) {
1620 if (i != 0)
1621 p << ", ";
1622 p << stringifyClauseTaskDepend(
1623 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*dependKinds)[i])
1624 .getValue())
1625 << " -> " << dependVars[i] << " : " << dependTypes[i];
1626 }
1627}
1628
1629/// Verifies Depend clause
1630static LogicalResult verifyDependVarList(Operation *op,
1631 std::optional<ArrayAttr> dependKinds,
1632 OperandRange dependVars) {
1633 if (!dependVars.empty()) {
1634 if (!dependKinds || dependKinds->size() != dependVars.size())
1635 return op->emitOpError() << "expected as many depend values"
1636 " as depend variables";
1637 } else {
1638 if (dependKinds && !dependKinds->empty())
1639 return op->emitOpError() << "unexpected depend values";
1640 return success();
1641 }
1642
1643 return success();
1644}
1645
1646//===----------------------------------------------------------------------===//
1647// Parser, printer and verifier for Synchronization Hint (2.17.12)
1648//===----------------------------------------------------------------------===//
1649
1650/// Parses a Synchronization Hint clause. The value of hint is an integer
1651/// which is a combination of different hints from `omp_sync_hint_t`.
1652///
1653/// hint-clause = `hint` `(` hint-value `)`
1654static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1655 IntegerAttr &hintAttr) {
1656 StringRef hintKeyword;
1657 int64_t hint = 0;
1658 if (succeeded(parser.parseOptionalKeyword("none"))) {
1659 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1660 return success();
1661 }
1662 auto parseKeyword = [&]() -> ParseResult {
1663 if (failed(parser.parseKeyword(&hintKeyword)))
1664 return failure();
1665 if (hintKeyword == "uncontended")
1666 hint |= 1;
1667 else if (hintKeyword == "contended")
1668 hint |= 2;
1669 else if (hintKeyword == "nonspeculative")
1670 hint |= 4;
1671 else if (hintKeyword == "speculative")
1672 hint |= 8;
1673 else
1674 return parser.emitError(parser.getCurrentLocation())
1675 << hintKeyword << " is not a valid hint";
1676 return success();
1677 };
1678 if (parser.parseCommaSeparatedList(parseKeyword))
1679 return failure();
1680 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1681 return success();
1682}
1683
1684/// Prints a Synchronization Hint clause
1686 IntegerAttr hintAttr) {
1687 int64_t hint = hintAttr.getInt();
1688
1689 if (hint == 0) {
1690 p << "none";
1691 return;
1692 }
1693
1694 // Helper function to get n-th bit from the right end of `value`
1695 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1696
1697 bool uncontended = bitn(hint, 0);
1698 bool contended = bitn(hint, 1);
1699 bool nonspeculative = bitn(hint, 2);
1700 bool speculative = bitn(hint, 3);
1701
1703 if (uncontended)
1704 hints.push_back("uncontended");
1705 if (contended)
1706 hints.push_back("contended");
1707 if (nonspeculative)
1708 hints.push_back("nonspeculative");
1709 if (speculative)
1710 hints.push_back("speculative");
1711
1712 llvm::interleaveComma(hints, p);
1713}
1714
1715/// Verifies a synchronization hint clause
1716static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1717
1718 // Helper function to get n-th bit from the right end of `value`
1719 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1720
1721 bool uncontended = bitn(hint, 0);
1722 bool contended = bitn(hint, 1);
1723 bool nonspeculative = bitn(hint, 2);
1724 bool speculative = bitn(hint, 3);
1725
1726 if (uncontended && contended)
1727 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1728 "omp_sync_hint_contended cannot be combined";
1729 if (nonspeculative && speculative)
1730 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1731 "omp_sync_hint_speculative cannot be combined.";
1732 return success();
1733}
1734
1735//===----------------------------------------------------------------------===//
1736// Parser, printer and verifier for Target
1737//===----------------------------------------------------------------------===//
1738
1739// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1740// boolean
1741static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1742 return (value & flag) == flag;
1743}
1744
1745/// Parses a map_entries map type from a string format back into its numeric
1746/// value.
1747///
1748/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1749/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1750static ParseResult parseMapClause(OpAsmParser &parser,
1751 ClauseMapFlagsAttr &mapType) {
1752 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1753 // This simply verifies the correct keyword is read in, the
1754 // keyword itself is stored inside of the operation
1755 auto parseTypeAndMod = [&]() -> ParseResult {
1756 StringRef mapTypeMod;
1757 if (parser.parseKeyword(&mapTypeMod))
1758 return failure();
1759
1760 if (mapTypeMod == "always")
1761 mapTypeBits |= ClauseMapFlags::always;
1762
1763 if (mapTypeMod == "implicit")
1764 mapTypeBits |= ClauseMapFlags::implicit;
1765
1766 if (mapTypeMod == "ompx_hold")
1767 mapTypeBits |= ClauseMapFlags::ompx_hold;
1768
1769 if (mapTypeMod == "close")
1770 mapTypeBits |= ClauseMapFlags::close;
1771
1772 if (mapTypeMod == "present")
1773 mapTypeBits |= ClauseMapFlags::present;
1774
1775 if (mapTypeMod == "to")
1776 mapTypeBits |= ClauseMapFlags::to;
1777
1778 if (mapTypeMod == "from")
1779 mapTypeBits |= ClauseMapFlags::from;
1780
1781 if (mapTypeMod == "tofrom")
1782 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1783
1784 if (mapTypeMod == "delete")
1785 mapTypeBits |= ClauseMapFlags::del;
1786
1787 if (mapTypeMod == "storage")
1788 mapTypeBits |= ClauseMapFlags::storage;
1789
1790 if (mapTypeMod == "return_param")
1791 mapTypeBits |= ClauseMapFlags::return_param;
1792
1793 if (mapTypeMod == "private")
1794 mapTypeBits |= ClauseMapFlags::priv;
1795
1796 if (mapTypeMod == "literal")
1797 mapTypeBits |= ClauseMapFlags::literal;
1798
1799 if (mapTypeMod == "attach")
1800 mapTypeBits |= ClauseMapFlags::attach;
1801
1802 if (mapTypeMod == "attach_always")
1803 mapTypeBits |= ClauseMapFlags::attach_always;
1804
1805 if (mapTypeMod == "attach_none")
1806 mapTypeBits |= ClauseMapFlags::attach_none;
1807
1808 if (mapTypeMod == "attach_auto")
1809 mapTypeBits |= ClauseMapFlags::attach_auto;
1810
1811 if (mapTypeMod == "ref_ptr")
1812 mapTypeBits |= ClauseMapFlags::ref_ptr;
1813
1814 if (mapTypeMod == "ref_ptee")
1815 mapTypeBits |= ClauseMapFlags::ref_ptee;
1816
1817 if (mapTypeMod == "ref_ptr_ptee")
1818 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1819
1820 return success();
1821 };
1822
1823 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1824 return failure();
1825
1826 mapType =
1827 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1828
1829 return success();
1830}
1831
1832/// Prints a map_entries map type from its numeric value out into its string
1833/// format.
1834static void printMapClause(OpAsmPrinter &p, Operation *op,
1835 ClauseMapFlagsAttr mapType) {
1837 ClauseMapFlags mapFlags = mapType.getValue();
1838
1839 // handling of always, close, present placed at the beginning of the string
1840 // to aid readability
1841 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
1842 mapTypeStrs.push_back("always");
1843 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
1844 mapTypeStrs.push_back("implicit");
1845 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
1846 mapTypeStrs.push_back("ompx_hold");
1847 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
1848 mapTypeStrs.push_back("close");
1849 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
1850 mapTypeStrs.push_back("present");
1851
1852 // special handling of to/from/tofrom/delete and release/alloc, release +
1853 // alloc are the abscense of one of the other flags, whereas tofrom requires
1854 // both the to and from flag to be set.
1855 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
1856 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
1857
1858 if (to && from)
1859 mapTypeStrs.push_back("tofrom");
1860 else if (from)
1861 mapTypeStrs.push_back("from");
1862 else if (to)
1863 mapTypeStrs.push_back("to");
1864
1865 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
1866 mapTypeStrs.push_back("delete");
1867 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
1868 mapTypeStrs.push_back("return_param");
1869 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
1870 mapTypeStrs.push_back("storage");
1871 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
1872 mapTypeStrs.push_back("private");
1873 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
1874 mapTypeStrs.push_back("literal");
1875 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
1876 mapTypeStrs.push_back("attach");
1877 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
1878 mapTypeStrs.push_back("attach_always");
1879 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_none))
1880 mapTypeStrs.push_back("attach_none");
1881 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
1882 mapTypeStrs.push_back("attach_auto");
1883 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
1884 mapTypeStrs.push_back("ref_ptr");
1885 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
1886 mapTypeStrs.push_back("ref_ptee");
1887 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
1888 mapTypeStrs.push_back("ref_ptr_ptee");
1889 if (mapFlags == ClauseMapFlags::none)
1890 mapTypeStrs.push_back("none");
1891
1892 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
1893 p << mapTypeStrs[i];
1894 if (i + 1 < mapTypeStrs.size()) {
1895 p << ", ";
1896 }
1897 }
1898}
1899
1900static ParseResult parseMembersIndex(OpAsmParser &parser,
1901 ArrayAttr &membersIdx) {
1902 SmallVector<Attribute> values, memberIdxs;
1903
1904 auto parseIndices = [&]() -> ParseResult {
1905 int64_t value;
1906 if (parser.parseInteger(value))
1907 return failure();
1908 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
1909 APInt(64, value, /*isSigned=*/false)));
1910 return success();
1911 };
1912
1913 do {
1914 if (failed(parser.parseLSquare()))
1915 return failure();
1916
1917 if (parser.parseCommaSeparatedList(parseIndices))
1918 return failure();
1919
1920 if (failed(parser.parseRSquare()))
1921 return failure();
1922
1923 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
1924 values.clear();
1925 } while (succeeded(parser.parseOptionalComma()));
1926
1927 if (!memberIdxs.empty())
1928 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
1929
1930 return success();
1931}
1932
1933static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
1934 ArrayAttr membersIdx) {
1935 if (!membersIdx)
1936 return;
1937
1938 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
1939 p << "[";
1940 auto memberIdx = cast<ArrayAttr>(v);
1941 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
1942 p << cast<IntegerAttr>(v2).getInt();
1943 });
1944 p << "]";
1945 });
1946}
1947
1949 VariableCaptureKindAttr mapCaptureType) {
1950 std::string typeCapStr;
1951 llvm::raw_string_ostream typeCap(typeCapStr);
1952 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
1953 typeCap << "ByRef";
1954 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
1955 typeCap << "ByCopy";
1956 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
1957 typeCap << "VLAType";
1958 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
1959 typeCap << "This";
1960 p << typeCapStr;
1961}
1962
1963static ParseResult parseCaptureType(OpAsmParser &parser,
1964 VariableCaptureKindAttr &mapCaptureType) {
1965 StringRef mapCaptureKey;
1966 if (parser.parseKeyword(&mapCaptureKey))
1967 return failure();
1968
1969 if (mapCaptureKey == "This")
1970 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1971 parser.getContext(), mlir::omp::VariableCaptureKind::This);
1972 if (mapCaptureKey == "ByRef")
1973 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1974 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
1975 if (mapCaptureKey == "ByCopy")
1976 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1977 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
1978 if (mapCaptureKey == "VLAType")
1979 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
1980 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
1981
1982 return success();
1983}
1984
1985static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
1988
1989 for (auto mapOp : mapVars) {
1990 if (!mapOp.getDefiningOp())
1991 return emitError(op->getLoc(), "missing map operation");
1992
1993 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
1994 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
1995
1996 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
1997 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
1998 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
1999
2000 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2001 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2002 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2003
2004 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2005 return emitError(op->getLoc(),
2006 "to, from, tofrom and alloc map types are permitted");
2007
2008 if (isa<TargetEnterDataOp>(op) && (from || del))
2009 return emitError(op->getLoc(), "to and alloc map types are permitted");
2010
2011 if (isa<TargetExitDataOp>(op) && to)
2012 return emitError(op->getLoc(),
2013 "from, release and delete map types are permitted");
2014
2015 if (isa<TargetUpdateOp>(op)) {
2016 if (del) {
2017 return emitError(op->getLoc(),
2018 "at least one of to or from map types must be "
2019 "specified, other map types are not permitted");
2020 }
2021
2022 if (!to && !from) {
2023 return emitError(op->getLoc(),
2024 "at least one of to or from map types must be "
2025 "specified, other map types are not permitted");
2026 }
2027
2028 auto updateVar = mapInfoOp.getVarPtr();
2029
2030 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2031 (from && updateToVars.contains(updateVar))) {
2032 return emitError(
2033 op->getLoc(),
2034 "either to or from map types can be specified, not both");
2035 }
2036
2037 if (always || close || implicit) {
2038 return emitError(
2039 op->getLoc(),
2040 "present, mapper and iterator map type modifiers are permitted");
2041 }
2042
2043 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2044 }
2045 } else if (!isa<DeclareMapperInfoOp>(op)) {
2046 return emitError(op->getLoc(),
2047 "map argument is not a map entry operation");
2048 }
2049 }
2050
2051 return success();
2052}
2053
2054static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2055 std::optional<DenseI64ArrayAttr> privateMapIndices =
2056 targetOp.getPrivateMapsAttr();
2057
2058 // None of the private operands are mapped.
2059 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2060 return success();
2061
2062 OperandRange privateVars = targetOp.getPrivateVars();
2063
2064 if (privateMapIndices.value().size() !=
2065 static_cast<int64_t>(privateVars.size()))
2066 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2067 "`private_maps` attribute mismatch");
2068
2069 return success();
2070}
2071
2072//===----------------------------------------------------------------------===//
2073// MapInfoOp
2074//===----------------------------------------------------------------------===//
2075
2076static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2077 StringRef clauseName,
2078 OperandRange vars) {
2079 for (Value var : vars)
2080 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2081 return op->emitOpError()
2082 << "'" << clauseName
2083 << "' arguments must be defined by 'omp.map.info' ops";
2084 return success();
2085}
2086
2087LogicalResult MapInfoOp::verify() {
2088 if (getMapperId() &&
2090 *this, getMapperIdAttr())) {
2091 return emitError("invalid mapper id");
2092 }
2093
2094 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2095 return failure();
2096
2097 return success();
2098}
2099
2100//===----------------------------------------------------------------------===//
2101// TargetDataOp
2102//===----------------------------------------------------------------------===//
2103
2104void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2105 const TargetDataOperands &clauses) {
2106 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2107 clauses.mapVars, clauses.useDeviceAddrVars,
2108 clauses.useDevicePtrVars);
2109}
2110
2111LogicalResult TargetDataOp::verify() {
2112 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2113 getUseDeviceAddrVars().empty()) {
2114 return ::emitError(this->getLoc(),
2115 "At least one of map, use_device_ptr_vars, or "
2116 "use_device_addr_vars operand must be present");
2117 }
2118
2119 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2120 getUseDevicePtrVars())))
2121 return failure();
2122
2123 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2124 getUseDeviceAddrVars())))
2125 return failure();
2126
2127 return verifyMapClause(*this, getMapVars());
2128}
2129
2130//===----------------------------------------------------------------------===//
2131// TargetEnterDataOp
2132//===----------------------------------------------------------------------===//
2133
2134void TargetEnterDataOp::build(
2135 OpBuilder &builder, OperationState &state,
2136 const TargetEnterExitUpdateDataOperands &clauses) {
2137 MLIRContext *ctx = builder.getContext();
2138 TargetEnterDataOp::build(builder, state,
2139 makeArrayAttr(ctx, clauses.dependKinds),
2140 clauses.dependVars, clauses.device, clauses.ifExpr,
2141 clauses.mapVars, clauses.nowait);
2142}
2143
2144LogicalResult TargetEnterDataOp::verify() {
2145 LogicalResult verifyDependVars =
2146 verifyDependVarList(*this, getDependKinds(), getDependVars());
2147 return failed(verifyDependVars) ? verifyDependVars
2148 : verifyMapClause(*this, getMapVars());
2149}
2150
2151//===----------------------------------------------------------------------===//
2152// TargetExitDataOp
2153//===----------------------------------------------------------------------===//
2154
2155void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2156 const TargetEnterExitUpdateDataOperands &clauses) {
2157 MLIRContext *ctx = builder.getContext();
2158 TargetExitDataOp::build(builder, state,
2159 makeArrayAttr(ctx, clauses.dependKinds),
2160 clauses.dependVars, clauses.device, clauses.ifExpr,
2161 clauses.mapVars, clauses.nowait);
2162}
2163
2164LogicalResult TargetExitDataOp::verify() {
2165 LogicalResult verifyDependVars =
2166 verifyDependVarList(*this, getDependKinds(), getDependVars());
2167 return failed(verifyDependVars) ? verifyDependVars
2168 : verifyMapClause(*this, getMapVars());
2169}
2170
2171//===----------------------------------------------------------------------===//
2172// TargetUpdateOp
2173//===----------------------------------------------------------------------===//
2174
2175void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2176 const TargetEnterExitUpdateDataOperands &clauses) {
2177 MLIRContext *ctx = builder.getContext();
2178 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2179 clauses.dependVars, clauses.device, clauses.ifExpr,
2180 clauses.mapVars, clauses.nowait);
2181}
2182
2183LogicalResult TargetUpdateOp::verify() {
2184 LogicalResult verifyDependVars =
2185 verifyDependVarList(*this, getDependKinds(), getDependVars());
2186 return failed(verifyDependVars) ? verifyDependVars
2187 : verifyMapClause(*this, getMapVars());
2188}
2189
2190//===----------------------------------------------------------------------===//
2191// TargetOp
2192//===----------------------------------------------------------------------===//
2193
2194void TargetOp::build(OpBuilder &builder, OperationState &state,
2195 const TargetOperands &clauses) {
2196 MLIRContext *ctx = builder.getContext();
2197 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2198 // inReductionByref, inReductionSyms.
2199 TargetOp::build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2200 clauses.bare, makeArrayAttr(ctx, clauses.dependKinds),
2201 clauses.dependVars, clauses.device, clauses.hasDeviceAddrVars,
2202 clauses.hostEvalVars, clauses.ifExpr,
2203 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2204 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars,
2205 clauses.mapVars, clauses.nowait, clauses.privateVars,
2206 makeArrayAttr(ctx, clauses.privateSyms),
2207 clauses.privateNeedsBarrier, clauses.threadLimit,
2208 /*private_maps=*/nullptr);
2209}
2210
2211LogicalResult TargetOp::verify() {
2212 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars())))
2213 return failure();
2214
2215 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2216 getHasDeviceAddrVars())))
2217 return failure();
2218
2219 if (failed(verifyMapClause(*this, getMapVars())))
2220 return failure();
2221
2222 return verifyPrivateVarsMapping(*this);
2223}
2224
/// Verifies the contents of the target region.
///
/// At most one omp.teams op may appear among the ops returned by
/// getOps<TeamsOp>(). Each user of a host_eval block argument must match one
/// of the allowed cases below; any other user is rejected:
///  - omp.teams: as the num_teams lower/upper bound or thread_limit;
///  - omp.parallel: as num_threads, only if the kernel flags contain spmd and
///    the parallel op is an ancestor of the innermost captured op;
///  - omp.loop_nest: as a lower bound, upper bound or step, only if the flags
///    contain trip_count and the loop nest is the innermost captured op.
LogicalResult TargetOp::verifyRegions() {
  auto teamsOps = getOps<TeamsOp>();
  if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
    return emitError("target containing multiple 'omp.teams' nested ops");

  // Check that host_eval values are only used in legal ways.
  Operation *capturedOp = getInnermostCapturedOmpOp();
  TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
  for (Value hostEvalArg :
       cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
    for (Operation *user : hostEvalArg.getUsers()) {
      // omp.teams: only num_teams bounds and thread_limit may use it.
      if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
        if (llvm::is_contained({teamsOp.getNumTeamsLower(),
                                teamsOp.getNumTeamsUpper(),
                                teamsOp.getThreadLimit()},
                               hostEvalArg))
          continue;

        return emitOpError() << "host_eval argument only legal as 'num_teams' "
                                "and 'thread_limit' in 'omp.teams'";
      }
      // omp.parallel: only num_threads, and only for SPMD kernels where this
      // parallel op encloses the captured op.
      if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
        if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
            parallelOp->isAncestor(capturedOp) &&
            hostEvalArg == parallelOp.getNumThreads())
          continue;

        return emitOpError()
               << "host_eval argument only legal as 'num_threads' in "
                  "'omp.parallel' when representing target SPMD";
      }
      // omp.loop_nest: only bounds/steps of the captured loop nest, and only
      // when the trip count is computed on the host.
      if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
        if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
            loopNestOp.getOperation() == capturedOp &&
            (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
             llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
          continue;

        return emitOpError() << "host_eval argument only legal as loop bounds "
                                "and steps in 'omp.loop_nest' when trip count "
                                "must be evaluated in the host";
      }

      // Any other kind of user is illegal.
      return emitOpError() << "host_eval argument illegal use in '"
                           << user->getName() << "' operation";
    }
  }
  return success();
}
2275
2276static Operation *
2277findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2278 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2279 assert(rootOp && "expected valid operation");
2280
2281 Dialect *ompDialect = rootOp->getDialect();
2282 Operation *capturedOp = nullptr;
2283 DominanceInfo domInfo;
2284
2285 // Process in pre-order to check operations from outermost to innermost,
2286 // ensuring we only enter the region of an operation if it meets the criteria
2287 // for being captured. We stop the exploration of nested operations as soon as
2288 // we process a region holding no operations to be captured.
2289 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2290 if (op == rootOp)
2291 return WalkResult::advance();
2292
2293 // Ignore operations of other dialects or omp operations with no regions,
2294 // because these will only be checked if they are siblings of an omp
2295 // operation that can potentially be captured.
2296 bool isOmpDialect = op->getDialect() == ompDialect;
2297 bool hasRegions = op->getNumRegions() > 0;
2298 if (!isOmpDialect || !hasRegions)
2299 return WalkResult::skip();
2300
2301 // This operation cannot be captured if it can be executed more than once
2302 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2303 // be executed before all exits of the region (i.e. it doesn't dominate all
2304 // blocks with no successors reachable from the entry block).
2305 if (checkSingleMandatoryExec) {
2306 Region *parentRegion = op->getParentRegion();
2307 Block *parentBlock = op->getBlock();
2308
2309 for (Block *successor : parentBlock->getSuccessors())
2310 if (successor->isReachable(parentBlock))
2311 return WalkResult::interrupt();
2312
2313 for (Block &block : *parentRegion)
2314 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2315 !domInfo.dominates(parentBlock, &block))
2316 return WalkResult::interrupt();
2317 }
2318
2319 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2320 // into nested operations.
2321 for (Operation &sibling : op->getParentRegion()->getOps())
2322 if (&sibling != op && !siblingAllowedFn(&sibling))
2323 return WalkResult::interrupt();
2324
2325 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2326 // Otherwise, process the contents of this operation.
2327 capturedOp = op;
2328 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2330 });
2331
2332 return capturedOp;
2333}
2334
2335Operation *TargetOp::getInnermostCapturedOmpOp() {
2336 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2337
2338 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2339 // effects, but don't include a memory write effect.
2340 return findCapturedOmpOp(
2341 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2342 if (!sibling)
2343 return false;
2344
2345 if (ompDialect == sibling->getDialect())
2346 return sibling->hasTrait<OpTrait::IsTerminator>();
2347
2348 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2350 effects;
2351 memOp.getEffects(effects);
2352 return !llvm::any_of(
2353 effects, [&](MemoryEffects::EffectInstance &effect) {
2354 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2355 isa<SideEffects::AutomaticAllocationScopeResource>(
2356 effect.getResource());
2357 });
2358 }
2359 return true;
2360 });
2361}
2362
2363/// Check if we can promote SPMD kernel to No-Loop kernel.
2364static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2365 WsloopOp *wsLoopOp) {
2366 // num_teams clause can break no-loop teams/threads assumption.
2367 if (teamsOp.getNumTeamsUpper())
2368 return false;
2369
2370 // Reduction kernels are slower in no-loop mode.
2371 if (teamsOp.getNumReductionVars())
2372 return false;
2373 if (wsLoopOp->getNumReductionVars())
2374 return false;
2375
2376 // Check if the user allows the promotion of kernels to no-loop mode.
2377 OffloadModuleInterface offloadMod =
2378 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2379 if (!offloadMod)
2380 return false;
2381 auto ompFlags = offloadMod.getFlags();
2382 if (!ompFlags)
2383 return false;
2384 return ompFlags.getAssumeTeamsOversubscription() &&
2385 ompFlags.getAssumeThreadsOversubscription();
2386}
2387
/// Computes the kernel execution flags of a target region from `capturedOp`.
/// `capturedOp` is either null or the result of getInnermostCapturedOmpOp()
/// on the enclosing omp.target. The result is:
///  - generic, unless the captured op is an omp.loop_nest whose wrappers match
///    one of the patterns below (an innermost omp.simd wrapper is ignored);
///  - spmd | trip_count, plus no_loop when canPromoteToNoLoop() allows it, for
///    target-teams-distribute-parallel-wsloop[-simd];
///  - spmd | trip_count for target-teams-loop;
///  - generic | trip_count for target-teams-distribute[-simd], plus spmd when
///    the outermost op captured inside the loop nest is an omp.parallel;
///  - spmd (without trip_count) for target-parallel-wsloop[-simd].
TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetRegionFlags::generic;

  // Get the innermost non-simd loop wrapper.
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Number of wrappers from the innermost non-simd one outwards; only 1 or 2
  // can match a recognized pattern.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetRegionFlags::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
    if (!wsloopOp)
      return TargetRegionFlags::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetRegionFlags::generic;

    // The distribute wrapper must be nested in omp.parallel, which in turn must
    // be nested in omp.teams directly inside this target.
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
    if (!teamsOp)
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation()) {
      TargetRegionFlags result =
          TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
      if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
        result = result | TargetRegionFlags::no_loop;
      return result;
    }
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetRegionFlags::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetRegionFlags::generic;

    if (isa<LoopOp>(innermostWrapper))
      return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;

    // Find single immediately nested captured omp.parallel and add spmd flag
    // (generic-spmd case).
    //
    // TODO: This shouldn't have to be done here, as it is too easy to break.
    // The openmp-opt pass should be updated to be able to promote kernels like
    // this from "Generic" to "Generic-SPMD". However, the use of the
    // `kmpc_distribute_static_loop` family of functions produced by the
    // OMPIRBuilder for these kernels prevents that from working.
    Dialect *ompDialect = targetOp->getDialect();
    Operation *nestedCapture = findCapturedOmpOp(
        capturedOp, /*checkSingleMandatoryExec=*/false,
        [&](Operation *sibling) {
          return sibling && (ompDialect != sibling->getDialect() ||
                             sibling->hasTrait<OpTrait::IsTerminator>());
        });

    TargetRegionFlags result =
        TargetRegionFlags::generic | TargetRegionFlags::trip_count;

    if (!nestedCapture)
      return result;

    // Climb to the ancestor of the nested capture that is a direct child of
    // the captured loop nest; spmd is added only if that op is omp.parallel.
    while (nestedCapture->getParentOp() != capturedOp)
      nestedCapture = nestedCapture->getParentOp();

    return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
                                          : result;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetRegionFlags::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetRegionFlags::spmd;
  }

  return TargetRegionFlags::generic;
}
2492
2493//===----------------------------------------------------------------------===//
2494// ParallelOp
2495//===----------------------------------------------------------------------===//
2496
2497void ParallelOp::build(OpBuilder &builder, OperationState &state,
2498 ArrayRef<NamedAttribute> attributes) {
2499 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2500 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2501 /*num_threads=*/nullptr, /*private_vars=*/ValueRange(),
2502 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2503 /*proc_bind_kind=*/nullptr,
2504 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2505 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2506 state.addAttributes(attributes);
2507}
2508
2509void ParallelOp::build(OpBuilder &builder, OperationState &state,
2510 const ParallelOperands &clauses) {
2511 MLIRContext *ctx = builder.getContext();
2512 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2513 clauses.ifExpr, clauses.numThreads, clauses.privateVars,
2514 makeArrayAttr(ctx, clauses.privateSyms),
2515 clauses.privateNeedsBarrier, clauses.procBindKind,
2516 clauses.reductionMod, clauses.reductionVars,
2517 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2518 makeArrayAttr(ctx, clauses.reductionSyms));
2519}
2520
2521template <typename OpType>
2522static LogicalResult verifyPrivateVarList(OpType &op) {
2523 auto privateVars = op.getPrivateVars();
2524 auto privateSyms = op.getPrivateSymsAttr();
2525
2526 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2527 return success();
2528
2529 auto numPrivateVars = privateVars.size();
2530 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2531
2532 if (numPrivateVars != numPrivateSyms)
2533 return op.emitError() << "inconsistent number of private variables and "
2534 "privatizer op symbols, private vars: "
2535 << numPrivateVars
2536 << " vs. privatizer op symbols: " << numPrivateSyms;
2537
2538 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2539 Type varType = std::get<0>(privateVarInfo).getType();
2540 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2541 PrivateClauseOp privatizerOp =
2543
2544 if (privatizerOp == nullptr)
2545 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2546 << privateSym << "'";
2547
2548 Type privatizerType = privatizerOp.getArgType();
2549
2550 if (privatizerType && (varType != privatizerType))
2551 return op.emitError()
2552 << "type mismatch between a "
2553 << (privatizerOp.getDataSharingType() ==
2554 DataSharingClauseType::Private
2555 ? "private"
2556 : "firstprivate")
2557 << " variable and its privatizer op, var type: " << varType
2558 << " vs. privatizer op type: " << privatizerType;
2559 }
2560
2561 return success();
2562}
2563
2564LogicalResult ParallelOp::verify() {
2565 if (getAllocateVars().size() != getAllocatorVars().size())
2566 return emitError(
2567 "expected equal sizes for allocate and allocator variables");
2568
2569 if (failed(verifyPrivateVarList(*this)))
2570 return failure();
2571
2572 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2573 getReductionByref());
2574}
2575
2576LogicalResult ParallelOp::verifyRegions() {
2577 auto distChildOps = getOps<DistributeOp>();
2578 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2579 if (numDistChildOps > 1)
2580 return emitError()
2581 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2582
2583 if (numDistChildOps == 1) {
2584 if (!isComposite())
2585 return emitError()
2586 << "'omp.composite' attribute missing from composite operation";
2587
2588 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2589 Operation &distributeOp = **distChildOps.begin();
2590 for (Operation &childOp : getOps()) {
2591 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2592 continue;
2593
2594 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2595 return emitError() << "unexpected OpenMP operation inside of composite "
2596 "'omp.parallel': "
2597 << childOp.getName();
2598 }
2599 } else if (isComposite()) {
2600 return emitError()
2601 << "'omp.composite' attribute present in non-composite operation";
2602 }
2603 return success();
2604}
2605
2606//===----------------------------------------------------------------------===//
2607// TeamsOp
2608//===----------------------------------------------------------------------===//
2609
2611 while ((op = op->getParentOp()))
2612 if (isa<OpenMPDialect>(op->getDialect()))
2613 return false;
2614 return true;
2615}
2616
2617void TeamsOp::build(OpBuilder &builder, OperationState &state,
2618 const TeamsOperands &clauses) {
2619 MLIRContext *ctx = builder.getContext();
2620 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2621 TeamsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2622 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpper,
2623 /*private_vars=*/{}, /*private_syms=*/nullptr,
2624 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2625 clauses.reductionVars,
2626 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2627 makeArrayAttr(ctx, clauses.reductionSyms),
2628 clauses.threadLimit);
2629}
2630
2631LogicalResult TeamsOp::verify() {
2632 // Check parent region
2633 // TODO If nested inside of a target region, also check that it does not
2634 // contain any statements, declarations or directives other than this
2635 // omp.teams construct. The issue is how to support the initialization of
2636 // this operation's own arguments (allow SSA values across omp.target?).
2637 Operation *op = getOperation();
2638 if (!isa<TargetOp>(op->getParentOp()) &&
2640 return emitError("expected to be nested inside of omp.target or not nested "
2641 "in any OpenMP dialect operations");
2642
2643 // Check for num_teams clause restrictions
2644 if (auto numTeamsLowerBound = getNumTeamsLower()) {
2645 auto numTeamsUpperBound = getNumTeamsUpper();
2646 if (!numTeamsUpperBound)
2647 return emitError("expected num_teams upper bound to be defined if the "
2648 "lower bound is defined");
2649 if (numTeamsLowerBound.getType() != numTeamsUpperBound.getType())
2650 return emitError(
2651 "expected num_teams upper bound and lower bound to be the same type");
2652 }
2653
2654 // Check for allocate clause restrictions
2655 if (getAllocateVars().size() != getAllocatorVars().size())
2656 return emitError(
2657 "expected equal sizes for allocate and allocator variables");
2658
2659 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2660 getReductionByref());
2661}
2662
2663//===----------------------------------------------------------------------===//
2664// SectionOp
2665//===----------------------------------------------------------------------===//
2666
2667OperandRange SectionOp::getPrivateVars() {
2668 return getParentOp().getPrivateVars();
2669}
2670
2671OperandRange SectionOp::getReductionVars() {
2672 return getParentOp().getReductionVars();
2673}
2674
2675//===----------------------------------------------------------------------===//
2676// SectionsOp
2677//===----------------------------------------------------------------------===//
2678
2679void SectionsOp::build(OpBuilder &builder, OperationState &state,
2680 const SectionsOperands &clauses) {
2681 MLIRContext *ctx = builder.getContext();
2682 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2683 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2684 clauses.nowait, /*private_vars=*/{},
2685 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2686 clauses.reductionMod, clauses.reductionVars,
2687 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2688 makeArrayAttr(ctx, clauses.reductionSyms));
2689}
2690
2691LogicalResult SectionsOp::verify() {
2692 if (getAllocateVars().size() != getAllocatorVars().size())
2693 return emitError(
2694 "expected equal sizes for allocate and allocator variables");
2695
2696 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2697 getReductionByref());
2698}
2699
2700LogicalResult SectionsOp::verifyRegions() {
2701 for (auto &inst : *getRegion().begin()) {
2702 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2703 return emitOpError()
2704 << "expected omp.section op or terminator op inside region";
2705 }
2706 }
2707
2708 return success();
2709}
2710
2711//===----------------------------------------------------------------------===//
2712// SingleOp
2713//===----------------------------------------------------------------------===//
2714
2715void SingleOp::build(OpBuilder &builder, OperationState &state,
2716 const SingleOperands &clauses) {
2717 MLIRContext *ctx = builder.getContext();
2718 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2719 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2720 clauses.copyprivateVars,
2721 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2722 /*private_vars=*/{}, /*private_syms=*/nullptr,
2723 /*private_needs_barrier=*/nullptr);
2724}
2725
2726LogicalResult SingleOp::verify() {
2727 // Check for allocate clause restrictions
2728 if (getAllocateVars().size() != getAllocatorVars().size())
2729 return emitError(
2730 "expected equal sizes for allocate and allocator variables");
2731
2732 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2733 getCopyprivateSyms());
2734}
2735
2736//===----------------------------------------------------------------------===//
2737// WorkshareOp
2738//===----------------------------------------------------------------------===//
2739
2740void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2741 const WorkshareOperands &clauses) {
2742 WorkshareOp::build(builder, state, clauses.nowait);
2743}
2744
2745//===----------------------------------------------------------------------===//
2746// WorkshareLoopWrapperOp
2747//===----------------------------------------------------------------------===//
2748
2749LogicalResult WorkshareLoopWrapperOp::verify() {
2750 if (!(*this)->getParentOfType<WorkshareOp>())
2751 return emitOpError() << "must be nested in an omp.workshare";
2752 return success();
2753}
2754
2755LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2756 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2757 getNestedWrapper())
2758 return emitOpError() << "expected to be a standalone loop wrapper";
2759
2760 return success();
2761}
2762
2763//===----------------------------------------------------------------------===//
2764// LoopWrapperInterface
2765//===----------------------------------------------------------------------===//
2766
2767LogicalResult LoopWrapperInterface::verifyImpl() {
2768 Operation *op = this->getOperation();
2769 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2771 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2772 "and `SingleBlock` traits";
2773
2774 if (op->getNumRegions() != 1)
2775 return emitOpError() << "loop wrapper does not contain exactly one region";
2776
2777 Region &region = op->getRegion(0);
2778 if (range_size(region.getOps()) != 1)
2779 return emitOpError()
2780 << "loop wrapper does not contain exactly one nested op";
2781
2782 Operation &firstOp = *region.op_begin();
2783 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2784 return emitOpError() << "nested in loop wrapper is not another loop "
2785 "wrapper or `omp.loop_nest`";
2786
2787 return success();
2788}
2789
2790//===----------------------------------------------------------------------===//
2791// LoopOp
2792//===----------------------------------------------------------------------===//
2793
2794void LoopOp::build(OpBuilder &builder, OperationState &state,
2795 const LoopOperands &clauses) {
2796 MLIRContext *ctx = builder.getContext();
2797
2798 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2799 makeArrayAttr(ctx, clauses.privateSyms),
2800 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2801 clauses.reductionMod, clauses.reductionVars,
2802 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2803 makeArrayAttr(ctx, clauses.reductionSyms));
2804}
2805
2806LogicalResult LoopOp::verify() {
2807 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2808 getReductionByref());
2809}
2810
2811LogicalResult LoopOp::verifyRegions() {
2812 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2813 getNestedWrapper())
2814 return emitOpError() << "expected to be a standalone loop wrapper";
2815
2816 return success();
2817}
2818
2819//===----------------------------------------------------------------------===//
2820// WsloopOp
2821//===----------------------------------------------------------------------===//
2822
2823void WsloopOp::build(OpBuilder &builder, OperationState &state,
2824 ArrayRef<NamedAttribute> attributes) {
2825 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
2826 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
2827 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
2828 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
2829 /*private_needs_barrier=*/false,
2830 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
2831 /*reduction_byref=*/nullptr,
2832 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
2833 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
2834 /*schedule_simd=*/false);
2835 state.addAttributes(attributes);
2836}
2837
2838void WsloopOp::build(OpBuilder &builder, OperationState &state,
2839 const WsloopOperands &clauses) {
2840 MLIRContext *ctx = builder.getContext();
2841 // TODO: Store clauses in op: allocateVars, allocatorVars
2842 WsloopOp::build(
2843 builder, state,
2844 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
2845 clauses.linearStepVars, clauses.nowait, clauses.order, clauses.orderMod,
2846 clauses.ordered, clauses.privateVars,
2847 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2848 clauses.reductionMod, clauses.reductionVars,
2849 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2850 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
2851 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
2852}
2853
2854LogicalResult WsloopOp::verify() {
2855 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2856 getReductionByref());
2857}
2858
2859LogicalResult WsloopOp::verifyRegions() {
2860 bool isCompositeChildLeaf =
2861 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2862
2863 if (LoopWrapperInterface nested = getNestedWrapper()) {
2864 if (!isComposite())
2865 return emitError()
2866 << "'omp.composite' attribute missing from composite wrapper";
2867
2868 // Check for the allowed leaf constructs that may appear in a composite
2869 // construct directly after DO/FOR.
2870 if (!isa<SimdOp>(nested))
2871 return emitError() << "only supported nested wrapper is 'omp.simd'";
2872
2873 } else if (isComposite() && !isCompositeChildLeaf) {
2874 return emitError()
2875 << "'omp.composite' attribute present in non-composite wrapper";
2876 } else if (!isComposite() && isCompositeChildLeaf) {
2877 return emitError()
2878 << "'omp.composite' attribute missing from composite wrapper";
2879 }
2880
2881 return success();
2882}
2883
2884//===----------------------------------------------------------------------===//
2885// Simd construct [2.9.3.1]
2886//===----------------------------------------------------------------------===//
2887
2888void SimdOp::build(OpBuilder &builder, OperationState &state,
2889 const SimdOperands &clauses) {
2890 MLIRContext *ctx = builder.getContext();
2891 // TODO Store clauses in op: linearVars, linearStepVars
2892 SimdOp::build(builder, state, clauses.alignedVars,
2893 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
2894 /*linear_vars=*/{}, /*linear_step_vars=*/{},
2895 clauses.nontemporalVars, clauses.order, clauses.orderMod,
2896 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
2897 clauses.privateNeedsBarrier, clauses.reductionMod,
2898 clauses.reductionVars,
2899 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2900 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
2901 clauses.simdlen);
2902}
2903
2904LogicalResult SimdOp::verify() {
2905 if (getSimdlen().has_value() && getSafelen().has_value() &&
2906 getSimdlen().value() > getSafelen().value())
2907 return emitOpError()
2908 << "simdlen clause and safelen clause are both present, but the "
2909 "simdlen value is not less than or equal to safelen value";
2910
2911 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
2912 return failure();
2913
2914 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
2915 return failure();
2916
2917 bool isCompositeChildLeaf =
2918 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
2919
2920 if (!isComposite() && isCompositeChildLeaf)
2921 return emitError()
2922 << "'omp.composite' attribute missing from composite wrapper";
2923
2924 if (isComposite() && !isCompositeChildLeaf)
2925 return emitError()
2926 << "'omp.composite' attribute present in non-composite wrapper";
2927
2928 // Firstprivate is not allowed for SIMD in the standard. Check that none of
2929 // the private decls are for firstprivate.
2930 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
2931 if (privateSyms) {
2932 for (const Attribute &sym : *privateSyms) {
2933 auto symRef = cast<SymbolRefAttr>(sym);
2934 omp::PrivateClauseOp privatizer =
2936 getOperation(), symRef);
2937 if (!privatizer)
2938 return emitError() << "Cannot find privatizer '" << symRef << "'";
2939 if (privatizer.getDataSharingType() ==
2940 DataSharingClauseType::FirstPrivate)
2941 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
2942 }
2943 }
2944
2945 return success();
2946}
2947
2948LogicalResult SimdOp::verifyRegions() {
2949 if (getNestedWrapper())
2950 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
2951
2952 return success();
2953}
2954
2955//===----------------------------------------------------------------------===//
2956// Distribute construct [2.9.4.1]
2957//===----------------------------------------------------------------------===//
2958
2959void DistributeOp::build(OpBuilder &builder, OperationState &state,
2960 const DistributeOperands &clauses) {
2961 DistributeOp::build(builder, state, clauses.allocateVars,
2962 clauses.allocatorVars, clauses.distScheduleStatic,
2963 clauses.distScheduleChunkSize, clauses.order,
2964 clauses.orderMod, clauses.privateVars,
2965 makeArrayAttr(builder.getContext(), clauses.privateSyms),
2966 clauses.privateNeedsBarrier);
2967}
2968
2969LogicalResult DistributeOp::verify() {
2970 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
2971 return emitOpError() << "chunk size set without "
2972 "dist_schedule_static being present";
2973
2974 if (getAllocateVars().size() != getAllocatorVars().size())
2975 return emitError(
2976 "expected equal sizes for allocate and allocator variables");
2977
2978 return success();
2979}
2980
2981LogicalResult DistributeOp::verifyRegions() {
2982 if (LoopWrapperInterface nested = getNestedWrapper()) {
2983 if (!isComposite())
2984 return emitError()
2985 << "'omp.composite' attribute missing from composite wrapper";
2986 // Check for the allowed leaf constructs that may appear in a composite
2987 // construct directly after DISTRIBUTE.
2988 if (isa<WsloopOp>(nested)) {
2989 Operation *parentOp = (*this)->getParentOp();
2990 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
2991 !cast<ComposableOpInterface>(parentOp).isComposite()) {
2992 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
2993 "when a composite 'omp.parallel' is the direct "
2994 "parent";
2995 }
2996 } else if (!isa<SimdOp>(nested))
2997 return emitError() << "only supported nested wrappers are 'omp.simd' and "
2998 "'omp.wsloop'";
2999 } else if (isComposite()) {
3000 return emitError()
3001 << "'omp.composite' attribute present in non-composite wrapper";
3002 }
3003
3004 return success();
3005}
3006
3007//===----------------------------------------------------------------------===//
3008// DeclareMapperOp / DeclareMapperInfoOp
3009//===----------------------------------------------------------------------===//
3010
3011LogicalResult DeclareMapperInfoOp::verify() {
3012 return verifyMapClause(*this, getMapVars());
3013}
3014
3015LogicalResult DeclareMapperOp::verifyRegions() {
3016 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3017 getRegion().getBlocks().front().getTerminator()))
3018 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3019
3020 return success();
3021}
3022
3023//===----------------------------------------------------------------------===//
3024// DeclareReductionOp
3025//===----------------------------------------------------------------------===//
3026
3027LogicalResult DeclareReductionOp::verifyRegions() {
3028 if (!getAllocRegion().empty()) {
3029 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3030 if (yieldOp.getResults().size() != 1 ||
3031 yieldOp.getResults().getTypes()[0] != getType())
3032 return emitOpError() << "expects alloc region to yield a value "
3033 "of the reduction type";
3034 }
3035 }
3036
3037 if (getInitializerRegion().empty())
3038 return emitOpError() << "expects non-empty initializer region";
3039 Block &initializerEntryBlock = getInitializerRegion().front();
3040
3041 if (initializerEntryBlock.getNumArguments() == 1) {
3042 if (!getAllocRegion().empty())
3043 return emitOpError() << "expects two arguments to the initializer region "
3044 "when an allocation region is used";
3045 } else if (initializerEntryBlock.getNumArguments() == 2) {
3046 if (getAllocRegion().empty())
3047 return emitOpError() << "expects one argument to the initializer region "
3048 "when no allocation region is used";
3049 } else {
3050 return emitOpError()
3051 << "expects one or two arguments to the initializer region";
3052 }
3053
3054 for (mlir::Value arg : initializerEntryBlock.getArguments())
3055 if (arg.getType() != getType())
3056 return emitOpError() << "expects initializer region argument to match "
3057 "the reduction type";
3058
3059 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3060 if (yieldOp.getResults().size() != 1 ||
3061 yieldOp.getResults().getTypes()[0] != getType())
3062 return emitOpError() << "expects initializer region to yield a value "
3063 "of the reduction type";
3064 }
3065
3066 if (getReductionRegion().empty())
3067 return emitOpError() << "expects non-empty reduction region";
3068 Block &reductionEntryBlock = getReductionRegion().front();
3069 if (reductionEntryBlock.getNumArguments() != 2 ||
3070 reductionEntryBlock.getArgumentTypes()[0] !=
3071 reductionEntryBlock.getArgumentTypes()[1] ||
3072 reductionEntryBlock.getArgumentTypes()[0] != getType())
3073 return emitOpError() << "expects reduction region with two arguments of "
3074 "the reduction type";
3075 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3076 if (yieldOp.getResults().size() != 1 ||
3077 yieldOp.getResults().getTypes()[0] != getType())
3078 return emitOpError() << "expects reduction region to yield a value "
3079 "of the reduction type";
3080 }
3081
3082 if (!getAtomicReductionRegion().empty()) {
3083 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3084 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3085 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3086 atomicReductionEntryBlock.getArgumentTypes()[1])
3087 return emitOpError() << "expects atomic reduction region with two "
3088 "arguments of the same type";
3089 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3090 atomicReductionEntryBlock.getArgumentTypes()[0]);
3091 if (!ptrType ||
3092 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3093 return emitOpError() << "expects atomic reduction region arguments to "
3094 "be accumulators containing the reduction type";
3095 }
3096
3097 if (getCleanupRegion().empty())
3098 return success();
3099 Block &cleanupEntryBlock = getCleanupRegion().front();
3100 if (cleanupEntryBlock.getNumArguments() != 1 ||
3101 cleanupEntryBlock.getArgument(0).getType() != getType())
3102 return emitOpError() << "expects cleanup region with one argument "
3103 "of the reduction type";
3104
3105 return success();
3106}
3107
3108//===----------------------------------------------------------------------===//
3109// TaskOp
3110//===----------------------------------------------------------------------===//
3111
3112void TaskOp::build(OpBuilder &builder, OperationState &state,
3113 const TaskOperands &clauses) {
3114 MLIRContext *ctx = builder.getContext();
3115 TaskOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3116 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3117 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3118 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3119 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3120 clauses.priority, /*private_vars=*/clauses.privateVars,
3121 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3122 clauses.privateNeedsBarrier, clauses.untied,
3123 clauses.eventHandle);
3124}
3125
3126LogicalResult TaskOp::verify() {
3127 LogicalResult verifyDependVars =
3128 verifyDependVarList(*this, getDependKinds(), getDependVars());
3129 return failed(verifyDependVars)
3130 ? verifyDependVars
3131 : verifyReductionVarList(*this, getInReductionSyms(),
3132 getInReductionVars(),
3133 getInReductionByref());
3134}
3135
3136//===----------------------------------------------------------------------===//
3137// TaskgroupOp
3138//===----------------------------------------------------------------------===//
3139
3140void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3141 const TaskgroupOperands &clauses) {
3142 MLIRContext *ctx = builder.getContext();
3143 TaskgroupOp::build(builder, state, clauses.allocateVars,
3144 clauses.allocatorVars, clauses.taskReductionVars,
3145 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3146 makeArrayAttr(ctx, clauses.taskReductionSyms));
3147}
3148
3149LogicalResult TaskgroupOp::verify() {
3150 return verifyReductionVarList(*this, getTaskReductionSyms(),
3151 getTaskReductionVars(),
3152 getTaskReductionByref());
3153}
3154
3155//===----------------------------------------------------------------------===//
3156// TaskloopOp
3157//===----------------------------------------------------------------------===//
3158
3159void TaskloopOp::build(OpBuilder &builder, OperationState &state,
3160 const TaskloopOperands &clauses) {
3161 MLIRContext *ctx = builder.getContext();
3162 TaskloopOp::build(
3163 builder, state, clauses.allocateVars, clauses.allocatorVars,
3164 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3165 clauses.inReductionVars,
3166 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3167 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3168 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3169 /*private_vars=*/clauses.privateVars,
3170 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3171 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3172 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3173 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3174}
3175
3176LogicalResult TaskloopOp::verify() {
3177 if (getAllocateVars().size() != getAllocatorVars().size())
3178 return emitError(
3179 "expected equal sizes for allocate and allocator variables");
3180 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3181 getReductionVars(), getReductionByref())) ||
3182 failed(verifyReductionVarList(*this, getInReductionSyms(),
3183 getInReductionVars(),
3184 getInReductionByref())))
3185 return failure();
3186
3187 if (!getReductionVars().empty() && getNogroup())
3188 return emitError("if a reduction clause is present on the taskloop "
3189 "directive, the nogroup clause must not be specified");
3190 for (auto var : getReductionVars()) {
3191 if (llvm::is_contained(getInReductionVars(), var))
3192 return emitError("the same list item cannot appear in both a reduction "
3193 "and an in_reduction clause");
3194 }
3195
3196 if (getGrainsize() && getNumTasks()) {
3197 return emitError(
3198 "the grainsize clause and num_tasks clause are mutually exclusive and "
3199 "may not appear on the same taskloop directive");
3200 }
3201
3202 return success();
3203}
3204
3205LogicalResult TaskloopOp::verifyRegions() {
3206 if (LoopWrapperInterface nested = getNestedWrapper()) {
3207 if (!isComposite())
3208 return emitError()
3209 << "'omp.composite' attribute missing from composite wrapper";
3210
3211 // Check for the allowed leaf constructs that may appear in a composite
3212 // construct directly after TASKLOOP.
3213 if (!isa<SimdOp>(nested))
3214 return emitError() << "only supported nested wrapper is 'omp.simd'";
3215 } else if (isComposite()) {
3216 return emitError()
3217 << "'omp.composite' attribute present in non-composite wrapper";
3218 }
3219
3220 return success();
3221}
3222
3223//===----------------------------------------------------------------------===//
3224// LoopNestOp
3225//===----------------------------------------------------------------------===//
3226
3227ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3228 // Parse an opening `(` followed by induction variables followed by `)`
3231 Type loopVarType;
3233 parser.parseColonType(loopVarType) ||
3234 // Parse loop bounds.
3235 parser.parseEqual() ||
3236 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3237 parser.parseKeyword("to") ||
3238 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3239 return failure();
3240
3241 for (auto &iv : ivs)
3242 iv.type = loopVarType;
3243
3244 auto *ctx = parser.getBuilder().getContext();
3245 // Parse "inclusive" flag.
3246 if (succeeded(parser.parseOptionalKeyword("inclusive")))
3247 result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3248
3249 // Parse step values.
3251 if (parser.parseKeyword("step") ||
3252 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3253 return failure();
3254
3255 // Parse collapse
3256 int64_t value = 0;
3257 if (!parser.parseOptionalKeyword("collapse") &&
3258 (parser.parseLParen() || parser.parseInteger(value) ||
3259 parser.parseRParen()))
3260 return failure();
3261 if (value > 1)
3262 result.addAttribute(
3263 "collapse_num_loops",
3264 IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3265
3266 // Parse tiles
3268 auto parseTiles = [&]() -> ParseResult {
3269 int64_t tile;
3270 if (parser.parseInteger(tile))
3271 return failure();
3272 tiles.push_back(tile);
3273 return success();
3274 };
3275
3276 if (!parser.parseOptionalKeyword("tiles") &&
3277 (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3278 parser.parseRParen()))
3279 return failure();
3280
3281 if (tiles.size() > 0)
3282 result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3283
3284 // Parse the body.
3285 Region *region = result.addRegion();
3286 if (parser.parseRegion(*region, ivs))
3287 return failure();
3288
3289 // Resolve operands.
3290 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3291 parser.resolveOperands(ubs, loopVarType, result.operands) ||
3292 parser.resolveOperands(steps, loopVarType, result.operands))
3293 return failure();
3294
3295 // Parse the optional attribute list.
3296 return parser.parseOptionalAttrDict(result.attributes);
3297}
3298
3299void LoopNestOp::print(OpAsmPrinter &p) {
3300 Region &region = getRegion();
3301 auto args = region.getArguments();
3302 p << " (" << args << ") : " << args[0].getType() << " = ("
3303 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3304 if (getLoopInclusive())
3305 p << "inclusive ";
3306 p << "step (" << getLoopSteps() << ") ";
3307 if (int64_t numCollapse = getCollapseNumLoops())
3308 if (numCollapse > 1)
3309 p << "collapse(" << numCollapse << ") ";
3310
3311 if (const auto tiles = getTileSizes())
3312 p << "tiles(" << tiles.value() << ") ";
3313
3314 p.printRegion(region, /*printEntryBlockArgs=*/false);
3315}
3316
3317void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3318 const LoopNestOperands &clauses) {
3319 MLIRContext *ctx = builder.getContext();
3320 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3321 clauses.loopLowerBounds, clauses.loopUpperBounds,
3322 clauses.loopSteps, clauses.loopInclusive,
3323 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3324}
3325
3326LogicalResult LoopNestOp::verify() {
3327 if (getLoopLowerBounds().empty())
3328 return emitOpError() << "must represent at least one loop";
3329
3330 if (getLoopLowerBounds().size() != getIVs().size())
3331 return emitOpError() << "number of range arguments and IVs do not match";
3332
3333 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3334 if (lb.getType() != iv.getType())
3335 return emitOpError()
3336 << "range argument type does not match corresponding IV type";
3337 }
3338
3339 uint64_t numIVs = getIVs().size();
3340
3341 if (const auto &numCollapse = getCollapseNumLoops())
3342 if (numCollapse > numIVs)
3343 return emitOpError()
3344 << "collapse value is larger than the number of loops";
3345
3346 if (const auto &tiles = getTileSizes())
3347 if (tiles.value().size() > numIVs)
3348 return emitOpError() << "too few canonical loops for tile dimensions";
3349
3350 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3351 return emitOpError() << "expects parent op to be a loop wrapper";
3352
3353 return success();
3354}
3355
3356void LoopNestOp::gatherWrappers(
3358 Operation *parent = (*this)->getParentOp();
3359 while (auto wrapper =
3360 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3361 wrappers.push_back(wrapper);
3362 parent = parent->getParentOp();
3363 }
3364}
3365
3366//===----------------------------------------------------------------------===//
3367// OpenMP canonical loop handling
3368//===----------------------------------------------------------------------===//
3369
3370std::tuple<NewCliOp, OpOperand *, OpOperand *>
3371mlir::omp ::decodeCli(Value cli) {
3372
3373 // Defining a CLI for a generated loop is optional; if there is none then
3374 // there is no followup-tranformation
3375 if (!cli)
3376 return {{}, nullptr, nullptr};
3377
3378 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3379 "Unexpected type of cli");
3380
3381 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3382 OpOperand *gen = nullptr;
3383 OpOperand *cons = nullptr;
3384 for (OpOperand &use : cli.getUses()) {
3385 auto op = cast<LoopTransformationInterface>(use.getOwner());
3386
3387 unsigned opnum = use.getOperandNumber();
3388 if (op.isGeneratee(opnum)) {
3389 assert(!gen && "Each CLI may have at most one def");
3390 gen = &use;
3391 } else if (op.isApplyee(opnum)) {
3392 assert(!cons && "Each CLI may have at most one consumer");
3393 cons = &use;
3394 } else {
3395 llvm_unreachable("Unexpected operand for a CLI");
3396 }
3397 }
3398
3399 return {create, gen, cons};
3400}
3401
3402void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3403 ::mlir::OperationState &odsState) {
3404 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3405}
3406
3407void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3408 Value result = getResult();
3409 auto [newCli, gen, cons] = decodeCli(result);
3410
3411 // Structured binding `gen` cannot be captured in lambdas before C++20
3412 OpOperand *generator = gen;
3413
3414 // Derive the CLI variable name from its generator:
3415 // * "canonloop" for omp.canonical_loop
3416 // * custom name for loop transformation generatees
3417 // * "cli" as fallback if no generator
3418 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3419 // at that level
3420 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3421 // the index of that region
3422 std::string cliName{"cli"};
3423 if (gen) {
3424 cliName =
3426 .Case([&](CanonicalLoopOp op) {
3427 return generateLoopNestingName("canonloop", op);
3428 })
3429 .Case([&](UnrollHeuristicOp op) -> std::string {
3430 llvm_unreachable("heuristic unrolling does not generate a loop");
3431 })
3432 .Case([&](TileOp op) -> std::string {
3433 auto [generateesFirst, generateesCount] =
3434 op.getGenerateesODSOperandIndexAndLength();
3435 unsigned firstGrid = generateesFirst;
3436 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3437 unsigned end = generateesFirst + generateesCount;
3438 unsigned opnum = generator->getOperandNumber();
3439 // In the OpenMP apply and looprange clauses, indices are 1-based
3440 if (firstGrid <= opnum && opnum < firstIntratile) {
3441 unsigned gridnum = opnum - firstGrid + 1;
3442 return ("grid" + Twine(gridnum)).str();
3443 }
3444 if (firstIntratile <= opnum && opnum < end) {
3445 unsigned intratilenum = opnum - firstIntratile + 1;
3446 return ("intratile" + Twine(intratilenum)).str();
3447 }
3448 llvm_unreachable("Unexpected generatee argument");
3449 })
3450 .DefaultUnreachable("TODO: Custom name for this operation");
3451 }
3452
3453 setNameFn(result, cliName);
3454}
3455
3456LogicalResult NewCliOp::verify() {
3457 Value cli = getResult();
3458
3459 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3460 "Unexpected type of cli");
3461
3462 // Check that the CLI is used in at most generator and one consumer
3463 OpOperand *gen = nullptr;
3464 OpOperand *cons = nullptr;
3465 for (mlir::OpOperand &use : cli.getUses()) {
3466 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3467
3468 unsigned opnum = use.getOperandNumber();
3469 if (op.isGeneratee(opnum)) {
3470 if (gen) {
3471 InFlightDiagnostic error =
3472 emitOpError("CLI must have at most one generator");
3473 error.attachNote(gen->getOwner()->getLoc())
3474 .append("first generator here:");
3475 error.attachNote(use.getOwner()->getLoc())
3476 .append("second generator here:");
3477 return error;
3478 }
3479
3480 gen = &use;
3481 } else if (op.isApplyee(opnum)) {
3482 if (cons) {
3483 InFlightDiagnostic error =
3484 emitOpError("CLI must have at most one consumer");
3485 error.attachNote(cons->getOwner()->getLoc())
3486 .append("first consumer here:")
3487 .appendOp(*cons->getOwner(),
3488 OpPrintingFlags().printGenericOpForm());
3489 error.attachNote(use.getOwner()->getLoc())
3490 .append("second consumer here:")
3491 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3492 return error;
3493 }
3494
3495 cons = &use;
3496 } else {
3497 llvm_unreachable("Unexpected operand for a CLI");
3498 }
3499 }
3500
3501 // If the CLI is source of a transformation, it must have a generator
3502 if (cons && !gen) {
3503 InFlightDiagnostic error = emitOpError("CLI has no generator");
3504 error.attachNote(cons->getOwner()->getLoc())
3505 .append("see consumer here: ")
3506 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3507 return error;
3508 }
3509
3510 return success();
3511}
3512
3513void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3514 Value tripCount) {
3515 odsState.addOperands(tripCount);
3516 odsState.addOperands(Value());
3517 (void)odsState.addRegion();
3518}
3519
3520void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3521 Value tripCount, ::mlir::Value cli) {
3522 odsState.addOperands(tripCount);
3523 odsState.addOperands(cli);
3524 (void)odsState.addRegion();
3525}
3526
3527void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3528 setNameFn(&getRegion().front(), "body_entry");
3529}
3530
3531void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3532 OpAsmSetValueNameFn setNameFn) {
3533 std::string ivName = generateLoopNestingName("iv", *this);
3534 setNameFn(region.getArgument(0), ivName);
3535}
3536
3537void CanonicalLoopOp::print(OpAsmPrinter &p) {
3538 if (getCli())
3539 p << '(' << getCli() << ')';
3540 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3541 << " in range(" << getTripCount() << ") ";
3542
3543 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3544 /*printBlockTerminators=*/true);
3545
3546 p.printOptionalAttrDict((*this)->getAttrs());
3547}
3548
3549mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3551 CanonicalLoopInfoType cliType =
3552 CanonicalLoopInfoType::get(parser.getContext());
3553
3554 // Parse (optional) omp.cli identifier
3556 SmallVector<mlir::Value, 1> cliOperand;
3557 if (!parser.parseOptionalLParen()) {
3558 if (parser.parseOperand(cli) ||
3559 parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3560 return failure();
3561 }
3562
3563 // We derive the type of tripCount from inductionVariable. MLIR requires the
3564 // type of tripCount to be known when calling resolveOperand so we have parse
3565 // the type before processing the inductionVariable.
3566 OpAsmParser::Argument inductionVariable;
3568 if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3569 parser.parseKeyword("in") || parser.parseKeyword("range") ||
3570 parser.parseLParen() || parser.parseOperand(tripcount) ||
3571 parser.parseRParen() ||
3572 parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3573 return failure();
3574
3575 // Parse the loop body.
3576 Region *region = result.addRegion();
3577 if (parser.parseRegion(*region, {inductionVariable}))
3578 return failure();
3579
3580 // We parsed the cli operand forst, but because it is optional, it must be
3581 // last in the operand list.
3582 result.operands.append(cliOperand);
3583
3584 // Parse the optional attribute list.
3585 if (parser.parseOptionalAttrDict(result.attributes))
3586 return failure();
3587
3588 return mlir::success();
3589}
3590
3591LogicalResult CanonicalLoopOp::verify() {
3592 // The region's entry must accept the induction variable
3593 // It can also be empty if just created
3594 if (!getRegion().empty()) {
3595 Region &region = getRegion();
3596 if (region.getNumArguments() != 1)
3597 return emitOpError(
3598 "Canonical loop region must have exactly one argument");
3599
3600 if (getInductionVar().getType() != getTripCount().getType())
3601 return emitOpError(
3602 "Region argument must be the same type as the trip count");
3603 }
3604
3605 return success();
3606}
3607
3608Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3609
3610std::pair<unsigned, unsigned>
3611CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3612 // No applyees
3613 return {0, 0};
3614}
3615
3616std::pair<unsigned, unsigned>
3617CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3618 return getODSOperandIndexAndLength(odsIndex_cli);
3619}
3620
3621//===----------------------------------------------------------------------===//
3622// UnrollHeuristicOp
3623//===----------------------------------------------------------------------===//
3624
3625void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3626 ::mlir::OperationState &odsState,
3627 ::mlir::Value cli) {
3628 odsState.addOperands(cli);
3629}
3630
3631void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3632 p << '(' << getApplyee() << ')';
3633
3634 p.printOptionalAttrDict((*this)->getAttrs());
3635}
3636
3637mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3639 auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3640
3641 if (parser.parseLParen())
3642 return failure();
3643
3645 if (parser.parseOperand(applyee) ||
3646 parser.resolveOperand(applyee, cliType, result.operands))
3647 return failure();
3648
3649 if (parser.parseRParen())
3650 return failure();
3651
3652 // Optional output loop (full unrolling has none)
3653 if (!parser.parseOptionalArrow()) {
3654 if (parser.parseLParen() || parser.parseRParen())
3655 return failure();
3656 }
3657
3658 // Parse the optional attribute list.
3659 if (parser.parseOptionalAttrDict(result.attributes))
3660 return failure();
3661
3662 return mlir::success();
3663}
3664
3665std::pair<unsigned, unsigned>
3666UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3667 return getODSOperandIndexAndLength(odsIndex_applyee);
3668}
3669
3670std::pair<unsigned, unsigned>
3671UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3672 return {0, 0};
3673}
3674
3675//===----------------------------------------------------------------------===//
3676// TileOp
3677//===----------------------------------------------------------------------===//
3678
3679static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3680 OperandRange generatees,
3681 OperandRange applyees) {
3682 if (!generatees.empty())
3683 p << '(' << llvm::interleaved(generatees) << ')';
3684
3685 if (!applyees.empty())
3686 p << " <- (" << llvm::interleaved(applyees) << ')';
3687}
3688
3689static ParseResult parseLoopTransformClis(
3690 OpAsmParser &parser,
3693 if (parser.parseOptionalLess()) {
3694 // Syntax 1: generatees present
3695
3696 if (parser.parseOperandList(generateesOperands,
3698 return failure();
3699
3700 if (parser.parseLess())
3701 return failure();
3702 } else {
3703 // Syntax 2: generatees omitted
3704 }
3705
3706 // Parse `<-` (`<` has already been parsed)
3707 if (parser.parseMinus())
3708 return failure();
3709
3710 if (parser.parseOperandList(applyeesOperands,
3712 return failure();
3713
3714 return success();
3715}
3716
3717LogicalResult TileOp::verify() {
3718 if (getApplyees().empty())
3719 return emitOpError() << "must apply to at least one loop";
3720
3721 if (getSizes().size() != getApplyees().size())
3722 return emitOpError() << "there must be one tile size for each applyee";
3723
3724 if (!getGeneratees().empty() &&
3725 2 * getSizes().size() != getGeneratees().size())
3726 return emitOpError()
3727 << "expecting two times the number of generatees than applyees";
3728
3729 DenseSet<Value> parentIVs;
3730
3731 Value parent = getApplyees().front();
3732 for (auto &&applyee : llvm::drop_begin(getApplyees())) {
3733 auto [parentCreate, parentGen, parentCons] = decodeCli(parent);
3734 auto [create, gen, cons] = decodeCli(applyee);
3735
3736 if (!parentGen)
3737 return emitOpError() << "applyee CLI has no generator";
3738
3739 auto parentLoop = dyn_cast_or_null<CanonicalLoopOp>(parentGen->getOwner());
3740 if (!parentGen)
3741 return emitOpError()
3742 << "currently only supports omp.canonical_loop as applyee";
3743
3744 parentIVs.insert(parentLoop.getInductionVar());
3745
3746 if (!gen)
3747 return emitOpError() << "applyee CLI has no generator";
3748 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
3749 if (!loop)
3750 return emitOpError()
3751 << "currently only supports omp.canonical_loop as applyee";
3752
3753 // Canonical loop must be perfectly nested, i.e. the body of the parent must
3754 // only contain the omp.canonical_loop of the nested loops, and
3755 // omp.terminator
3756 bool isPerfectlyNested = [&]() {
3757 auto &parentBody = parentLoop.getRegion();
3758 if (!parentBody.hasOneBlock())
3759 return false;
3760 auto &parentBlock = parentBody.getBlocks().front();
3761
3762 auto nestedLoopIt = parentBlock.begin();
3763 if (nestedLoopIt == parentBlock.end() ||
3764 (&*nestedLoopIt != loop.getOperation()))
3765 return false;
3766
3767 auto termIt = std::next(nestedLoopIt);
3768 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
3769 return false;
3770
3771 if (std::next(termIt) != parentBlock.end())
3772 return false;
3773
3774 return true;
3775 }();
3776 if (!isPerfectlyNested)
3777 return emitOpError() << "tiled loop nest must be perfectly nested";
3778
3779 if (parentIVs.contains(loop.getTripCount()))
3780 return emitOpError() << "tiled loop nest must be rectangular";
3781
3782 parent = applyee;
3783 }
3784
3785 // TODO: The tile sizes must be computed before the loop, but checking this
3786 // requires dominance analysis. For instance:
3787 //
3788 // %canonloop = omp.new_cli
3789 // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
3790 // // write to %x
3791 // omp.terminator
3792 // }
3793 // %ts = llvm.load %x
3794 // omp.tile <- (%canonloop) sizes(%ts : i32)
3795
3796 return success();
3797}
3798
3799std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
3800 return getODSOperandIndexAndLength(odsIndex_applyees);
3801}
3802
3803std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
3804 return getODSOperandIndexAndLength(odsIndex_generatees);
3805}
3806
3807//===----------------------------------------------------------------------===//
3808// Critical construct (2.17.1)
3809//===----------------------------------------------------------------------===//
3810
3811void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
3812 const CriticalDeclareOperands &clauses) {
3813 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
3814}
3815
3816LogicalResult CriticalDeclareOp::verify() {
3817 return verifySynchronizationHint(*this, getHint());
3818}
3819
3820LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
3821 if (getNameAttr()) {
3822 SymbolRefAttr symbolRef = getNameAttr();
3823 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
3824 *this, symbolRef);
3825 if (!decl) {
3826 return emitOpError() << "expected symbol reference " << symbolRef
3827 << " to point to a critical declaration";
3828 }
3829 }
3830
3831 return success();
3832}
3833
3834//===----------------------------------------------------------------------===//
3835// Ordered construct
3836//===----------------------------------------------------------------------===//
3837
3838static LogicalResult verifyOrderedParent(Operation &op) {
3839 bool hasRegion = op.getNumRegions() > 0;
3840 auto loopOp = op.getParentOfType<LoopNestOp>();
3841 if (!loopOp) {
3842 if (hasRegion)
3843 return success();
3844
3845 // TODO: Consider if this needs to be the case only for the standalone
3846 // variant of the ordered construct.
3847 return op.emitOpError() << "must be nested inside of a loop";
3848 }
3849
3850 Operation *wrapper = loopOp->getParentOp();
3851 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
3852 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
3853 if (!orderedAttr)
3854 return op.emitOpError() << "the enclosing worksharing-loop region must "
3855 "have an ordered clause";
3856
3857 if (hasRegion && orderedAttr.getInt() != 0)
3858 return op.emitOpError() << "the enclosing loop's ordered clause must not "
3859 "have a parameter present";
3860
3861 if (!hasRegion && orderedAttr.getInt() == 0)
3862 return op.emitOpError() << "the enclosing loop's ordered clause must "
3863 "have a parameter present";
3864 } else if (!isa<SimdOp>(wrapper)) {
3865 return op.emitOpError() << "must be nested inside of a worksharing, simd "
3866 "or worksharing simd loop";
3867 }
3868 return success();
3869}
3870
3871void OrderedOp::build(OpBuilder &builder, OperationState &state,
3872 const OrderedOperands &clauses) {
3873 OrderedOp::build(builder, state, clauses.doacrossDependType,
3874 clauses.doacrossNumLoops, clauses.doacrossDependVars);
3875}
3876
3877LogicalResult OrderedOp::verify() {
3878 if (failed(verifyOrderedParent(**this)))
3879 return failure();
3880
3881 auto wrapper = (*this)->getParentOfType<WsloopOp>();
3882 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
3883 return emitOpError() << "number of variables in depend clause does not "
3884 << "match number of iteration variables in the "
3885 << "doacross loop";
3886
3887 return success();
3888}
3889
3890void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
3891 const OrderedRegionOperands &clauses) {
3892 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
3893}
3894
3895LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
3896
3897//===----------------------------------------------------------------------===//
3898// TaskwaitOp
3899//===----------------------------------------------------------------------===//
3900
3901void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
3902 const TaskwaitOperands &clauses) {
3903 // TODO Store clauses in op: dependKinds, dependVars, nowait.
3904 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
3905 /*depend_vars=*/{}, /*nowait=*/nullptr);
3906}
3907
3908//===----------------------------------------------------------------------===//
3909// Verifier for AtomicReadOp
3910//===----------------------------------------------------------------------===//
3911
3912LogicalResult AtomicReadOp::verify() {
3913 if (verifyCommon().failed())
3914 return mlir::failure();
3915
3916 if (auto mo = getMemoryOrder()) {
3917 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3918 *mo == ClauseMemoryOrderKind::Release) {
3919 return emitError(
3920 "memory-order must not be acq_rel or release for atomic reads");
3921 }
3922 }
3923 return verifySynchronizationHint(*this, getHint());
3924}
3925
3926//===----------------------------------------------------------------------===//
3927// Verifier for AtomicWriteOp
3928//===----------------------------------------------------------------------===//
3929
3930LogicalResult AtomicWriteOp::verify() {
3931 if (verifyCommon().failed())
3932 return mlir::failure();
3933
3934 if (auto mo = getMemoryOrder()) {
3935 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3936 *mo == ClauseMemoryOrderKind::Acquire) {
3937 return emitError(
3938 "memory-order must not be acq_rel or acquire for atomic writes");
3939 }
3940 }
3941 return verifySynchronizationHint(*this, getHint());
3942}
3943
3944//===----------------------------------------------------------------------===//
3945// Verifier for AtomicUpdateOp
3946//===----------------------------------------------------------------------===//
3947
3948LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
3949 PatternRewriter &rewriter) {
3950 if (op.isNoOp()) {
3951 rewriter.eraseOp(op);
3952 return success();
3953 }
3954 if (Value writeVal = op.getWriteOpVal()) {
3955 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
3956 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
3957 return success();
3958 }
3959 return failure();
3960}
3961
3962LogicalResult AtomicUpdateOp::verify() {
3963 if (verifyCommon().failed())
3964 return mlir::failure();
3965
3966 if (auto mo = getMemoryOrder()) {
3967 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
3968 *mo == ClauseMemoryOrderKind::Acquire) {
3969 return emitError(
3970 "memory-order must not be acq_rel or acquire for atomic updates");
3971 }
3972 }
3973
3974 return verifySynchronizationHint(*this, getHint());
3975}
3976
3977LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
3978
3979//===----------------------------------------------------------------------===//
3980// Verifier for AtomicCaptureOp
3981//===----------------------------------------------------------------------===//
3982
3983AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
3984 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
3985 return op;
3986 return dyn_cast<AtomicReadOp>(getSecondOp());
3987}
3988
3989AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
3990 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
3991 return op;
3992 return dyn_cast<AtomicWriteOp>(getSecondOp());
3993}
3994
3995AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
3996 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
3997 return op;
3998 return dyn_cast<AtomicUpdateOp>(getSecondOp());
3999}
4000
4001LogicalResult AtomicCaptureOp::verify() {
4002 return verifySynchronizationHint(*this, getHint());
4003}
4004
4005LogicalResult AtomicCaptureOp::verifyRegions() {
4006 if (verifyRegionsCommon().failed())
4007 return mlir::failure();
4008
4009 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4010 return emitOpError(
4011 "operations inside capture region must not have hint clause");
4012
4013 if (getFirstOp()->getAttr("memory_order") ||
4014 getSecondOp()->getAttr("memory_order"))
4015 return emitOpError(
4016 "operations inside capture region must not have memory_order clause");
4017 return success();
4018}
4019
4020//===----------------------------------------------------------------------===//
4021// CancelOp
4022//===----------------------------------------------------------------------===//
4023
4024void CancelOp::build(OpBuilder &builder, OperationState &state,
4025 const CancelOperands &clauses) {
4026 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4027}
4028
4030 Operation *parent = thisOp->getParentOp();
4031 while (parent) {
4032 if (parent->getDialect() == thisOp->getDialect())
4033 return parent;
4034 parent = parent->getParentOp();
4035 }
4036 return nullptr;
4037}
4038
4039LogicalResult CancelOp::verify() {
4040 ClauseCancellationConstructType cct = getCancelDirective();
4041 // The next OpenMP operation in the chain of parents
4042 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4043 if (!structuralParent)
4044 return emitOpError() << "Orphaned cancel construct";
4045
4046 if ((cct == ClauseCancellationConstructType::Parallel) &&
4047 !mlir::isa<ParallelOp>(structuralParent)) {
4048 return emitOpError() << "cancel parallel must appear "
4049 << "inside a parallel region";
4050 }
4051 if (cct == ClauseCancellationConstructType::Loop) {
4052 // structural parent will be omp.loop_nest, directly nested inside
4053 // omp.wsloop
4054 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4055
4056 if (!wsloopOp) {
4057 return emitOpError()
4058 << "cancel loop must appear inside a worksharing-loop region";
4059 }
4060 if (wsloopOp.getNowaitAttr()) {
4061 return emitError() << "A worksharing construct that is canceled "
4062 << "must not have a nowait clause";
4063 }
4064 if (wsloopOp.getOrderedAttr()) {
4065 return emitError() << "A worksharing construct that is canceled "
4066 << "must not have an ordered clause";
4067 }
4068
4069 } else if (cct == ClauseCancellationConstructType::Sections) {
4070 // structural parent will be an omp.section, directly nested inside
4071 // omp.sections
4072 auto sectionsOp =
4073 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4074 if (!sectionsOp) {
4075 return emitOpError() << "cancel sections must appear "
4076 << "inside a sections region";
4077 }
4078 if (sectionsOp.getNowait()) {
4079 return emitError() << "A sections construct that is canceled "
4080 << "must not have a nowait clause";
4081 }
4082 }
4083 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4084 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4085 !mlir::isa<omp::TaskloopOp>(structuralParent->getParentOp()))) {
4086 return emitOpError() << "cancel taskgroup must appear "
4087 << "inside a task region";
4088 }
4089 return success();
4090}
4091
4092//===----------------------------------------------------------------------===//
4093// CancellationPointOp
4094//===----------------------------------------------------------------------===//
4095
4096void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4097 const CancellationPointOperands &clauses) {
4098 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4099}
4100
4101LogicalResult CancellationPointOp::verify() {
4102 ClauseCancellationConstructType cct = getCancelDirective();
4103 // The next OpenMP operation in the chain of parents
4104 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4105 if (!structuralParent)
4106 return emitOpError() << "Orphaned cancellation point";
4107
4108 if ((cct == ClauseCancellationConstructType::Parallel) &&
4109 !mlir::isa<ParallelOp>(structuralParent)) {
4110 return emitOpError() << "cancellation point parallel must appear "
4111 << "inside a parallel region";
4112 }
4113 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4114 // find the wsloop
4115 if ((cct == ClauseCancellationConstructType::Loop) &&
4116 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4117 return emitOpError() << "cancellation point loop must appear "
4118 << "inside a worksharing-loop region";
4119 }
4120 if ((cct == ClauseCancellationConstructType::Sections) &&
4121 !mlir::isa<omp::SectionOp>(structuralParent)) {
4122 return emitOpError() << "cancellation point sections must appear "
4123 << "inside a sections region";
4124 }
4125 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4126 !mlir::isa<omp::TaskOp>(structuralParent)) {
4127 return emitOpError() << "cancellation point taskgroup must appear "
4128 << "inside a task region";
4129 }
4130 return success();
4131}
4132
4133//===----------------------------------------------------------------------===//
4134// MapBoundsOp
4135//===----------------------------------------------------------------------===//
4136
4137LogicalResult MapBoundsOp::verify() {
4138 auto extent = getExtent();
4139 auto upperbound = getUpperBound();
4140 if (!extent && !upperbound)
4141 return emitError("expected extent or upperbound.");
4142 return success();
4143}
4144
4145void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4146 TypeRange /*result_types*/, StringAttr symName,
4147 TypeAttr type) {
4148 PrivateClauseOp::build(
4149 odsBuilder, odsState, symName, type,
4150 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4151 DataSharingClauseType::Private));
4152}
4153
4154LogicalResult PrivateClauseOp::verifyRegions() {
4155 Type argType = getArgType();
4156 auto verifyTerminator = [&](Operation *terminator,
4157 bool yieldsValue) -> LogicalResult {
4158 if (!terminator->getBlock()->getSuccessors().empty())
4159 return success();
4160
4161 if (!llvm::isa<YieldOp>(terminator))
4162 return mlir::emitError(terminator->getLoc())
4163 << "expected exit block terminator to be an `omp.yield` op.";
4164
4165 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4166 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4167
4168 if (!yieldsValue) {
4169 if (yieldedTypes.empty())
4170 return success();
4171
4172 return mlir::emitError(terminator->getLoc())
4173 << "Did not expect any values to be yielded.";
4174 }
4175
4176 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4177 return success();
4178
4179 auto error = mlir::emitError(yieldOp.getLoc())
4180 << "Invalid yielded value. Expected type: " << argType
4181 << ", got: ";
4182
4183 if (yieldedTypes.empty())
4184 error << "None";
4185 else
4186 error << yieldedTypes;
4187
4188 return error;
4189 };
4190
4191 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4192 StringRef regionName,
4193 bool yieldsValue) -> LogicalResult {
4194 assert(!region.empty());
4195
4196 if (region.getNumArguments() != expectedNumArgs)
4197 return mlir::emitError(region.getLoc())
4198 << "`" << regionName << "`: "
4199 << "expected " << expectedNumArgs
4200 << " region arguments, got: " << region.getNumArguments();
4201
4202 for (Block &block : region) {
4203 // MLIR will verify the absence of the terminator for us.
4204 if (!block.mightHaveTerminator())
4205 continue;
4206
4207 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4208 return failure();
4209 }
4210
4211 return success();
4212 };
4213
4214 // Ensure all of the region arguments have the same type
4215 for (Region *region : getRegions())
4216 for (Type ty : region->getArgumentTypes())
4217 if (ty != argType)
4218 return emitError() << "Region argument type mismatch: got " << ty
4219 << " expected " << argType << ".";
4220
4221 mlir::Region &initRegion = getInitRegion();
4222 if (!initRegion.empty() &&
4223 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4224 /*yieldsValue=*/true)))
4225 return failure();
4226
4227 DataSharingClauseType dsType = getDataSharingType();
4228
4229 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4230 return emitError("`private` clauses do not require a `copy` region.");
4231
4232 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4233 return emitError(
4234 "`firstprivate` clauses require at least a `copy` region.");
4235
4236 if (dsType == DataSharingClauseType::FirstPrivate &&
4237 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4238 /*yieldsValue=*/true)))
4239 return failure();
4240
4241 if (!getDeallocRegion().empty() &&
4242 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4243 /*yieldsValue=*/false)))
4244 return failure();
4245
4246 return success();
4247}
4248
4249//===----------------------------------------------------------------------===//
4250// Spec 5.2: Masked construct (10.5)
4251//===----------------------------------------------------------------------===//
4252
4253void MaskedOp::build(OpBuilder &builder, OperationState &state,
4254 const MaskedOperands &clauses) {
4255 MaskedOp::build(builder, state, clauses.filteredThreadId);
4256}
4257
4258//===----------------------------------------------------------------------===//
4259// Spec 5.2: Scan construct (5.6)
4260//===----------------------------------------------------------------------===//
4261
4262void ScanOp::build(OpBuilder &builder, OperationState &state,
4263 const ScanOperands &clauses) {
4264 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4265}
4266
4267LogicalResult ScanOp::verify() {
4268 if (hasExclusiveVars() == hasInclusiveVars())
4269 return emitError(
4270 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4271 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4272 if (parentWsLoopOp.getReductionModAttr() &&
4273 parentWsLoopOp.getReductionModAttr().getValue() ==
4274 ReductionModifier::inscan)
4275 return success();
4276 }
4277 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4278 if (parentSimdOp.getReductionModAttr() &&
4279 parentSimdOp.getReductionModAttr().getValue() ==
4280 ReductionModifier::inscan)
4281 return success();
4282 }
4283 return emitError("SCAN directive needs to be enclosed within a parent "
4284 "worksharing loop construct or SIMD construct with INSCAN "
4285 "reduction modifier");
4286}
4287
4288/// Verifies align clause in allocate directive
4289
4290LogicalResult AllocateDirOp::verify() {
4291 std::optional<uint64_t> align = this->getAlign();
4292
4293 if (align.has_value()) {
4294 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4295 return emitError() << "ALIGN value : " << align.value()
4296 << " must be power of 2";
4297 }
4298
4299 return success();
4300}
4301
4302//===----------------------------------------------------------------------===//
4303// TargetAllocMemOp
4304//===----------------------------------------------------------------------===//
4305
4306mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4307 return getInTypeAttr().getValue();
4308}
4309
4310/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4311/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4312/// attr-dict-without-keyword
4313static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4315 auto &builder = parser.getBuilder();
4316 bool hasOperands = false;
4317 std::int32_t typeparamsSize = 0;
4318
4319 // Parse device number as a new operand
4321 mlir::Type deviceType;
4322 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4323 return mlir::failure();
4324 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4325 return mlir::failure();
4326 if (parser.parseComma())
4327 return mlir::failure();
4328
4329 mlir::Type intype;
4330 if (parser.parseType(intype))
4331 return mlir::failure();
4332 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4335 if (!parser.parseOptionalLParen()) {
4336 // parse the LEN params of the derived type. (<params> : <types>)
4338 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4339 return mlir::failure();
4340 typeparamsSize = operands.size();
4341 hasOperands = true;
4342 }
4343 std::int32_t shapeSize = 0;
4344 if (!parser.parseOptionalComma()) {
4345 // parse size to scale by, vector of n dimensions of type index
4347 return mlir::failure();
4348 shapeSize = operands.size() - typeparamsSize;
4349 auto idxTy = builder.getIndexType();
4350 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4351 typeVec.push_back(idxTy);
4352 hasOperands = true;
4353 }
4354 if (hasOperands &&
4355 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4356 result.operands))
4357 return mlir::failure();
4358
4359 mlir::Type restype = builder.getIntegerType(64);
4360 if (!restype) {
4361 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4362 return mlir::failure();
4363 }
4364 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4365 result.addAttribute("operandSegmentSizes",
4366 builder.getDenseI32ArrayAttr(segmentSizes));
4367 if (parser.parseOptionalAttrDict(result.attributes) ||
4368 parser.addTypeToList(restype, result.types))
4369 return mlir::failure();
4370 return mlir::success();
4371}
4372
4373mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4375 return parseTargetAllocMemOp(parser, result);
4376}
4377
4378void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4379 p << " ";
4381 p << " : ";
4382 p << getDevice().getType();
4383 p << ", ";
4384 p << getInType();
4385 if (!getTypeparams().empty()) {
4386 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4387 }
4388 for (auto sh : getShape()) {
4389 p << ", ";
4390 p.printOperand(sh);
4391 }
4392 p.printOptionalAttrDict((*this)->getAttrs(),
4393 {"in_type", "operandSegmentSizes"});
4394}
4395
4396llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4397 mlir::Type outType = getType();
4398 if (!mlir::dyn_cast<IntegerType>(outType))
4399 return emitOpError("must be a integer type");
4400 return mlir::success();
4401}
4402
4403//===----------------------------------------------------------------------===//
4404// WorkdistributeOp
4405//===----------------------------------------------------------------------===//
4406
4407LogicalResult WorkdistributeOp::verify() {
4408 // Check that region exists and is not empty
4409 Region &region = getRegion();
4410 if (region.empty())
4411 return emitOpError("region cannot be empty");
4412 // Verify single entry point.
4413 Block &entryBlock = region.front();
4414 if (entryBlock.empty())
4415 return emitOpError("region must contain a structured block");
4416 // Verify single exit point.
4417 bool hasTerminator = false;
4418 for (Block &block : region) {
4419 if (isa<TerminatorOp>(block.back())) {
4420 if (hasTerminator) {
4421 return emitOpError("region must have exactly one terminator");
4422 }
4423 hasTerminator = true;
4424 }
4425 }
4426 if (!hasTerminator) {
4427 return emitOpError("region must be terminated with omp.terminator");
4428 }
4429 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4430 // No implicit barrier at end
4431 if (isa<BarrierOp>(op)) {
4432 return emitOpError(
4433 "explicit barriers are not allowed in workdistribute region");
4434 }
4435 // Check for invalid nested constructs
4436 if (isa<ParallelOp>(op)) {
4437 return emitOpError(
4438 "nested parallel constructs not allowed in workdistribute");
4439 }
4440 if (isa<TeamsOp>(op)) {
4441 return emitOpError(
4442 "nested teams constructs not allowed in workdistribute");
4443 }
4444 return WalkResult::advance();
4445 });
4446 if (walkResult.wasInterrupted())
4447 return failure();
4448
4449 Operation *parentOp = (*this)->getParentOp();
4450 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4451 return emitOpError("workdistribute must be nested under teams");
4452 return success();
4453}
4454
4455#define GET_ATTRDEF_CLASSES
4456#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
4457
4458#define GET_OP_CLASSES
4459#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
4460
4461#define GET_TYPEDEF_CLASSES
4462#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
static SmallVector< Value > getTileSizes(Location loc, amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1371
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars)
Print Linear Clause.
*attr dict without keyword *static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds)
Print Depend clause.
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars)
Verifies Depend clause.
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:149
bool empty()
Definition Block.h:148
BlockArgument getArgument(unsigned i)
Definition Block.h:129
unsigned getNumArguments()
Definition Block.h:128
Operation & front()
Definition Block.h:153
SuccessorRange getSuccessors()
Definition Block.h:270
BlockArgListType getArguments()
Definition Block.h:87
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:163
IntegerType getI64Type()
Definition Builders.cpp:65
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:67
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:51
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:98
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:207
This class represents an operand of an operation.
Definition Value.h:257
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:43
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:220
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:686
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:749
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:213
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:674
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:223
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:234
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:238
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:677
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:797
user_range getUsers()
Returns a range of all users.
Definition Operation.h:873
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:230
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:216
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:37
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:387
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
bool isPerfectlyNested(ArrayRef< AffineForOp > loops)
Returns true if loops is a perfectly nested loop nest, where loops appear in it from outermost to inn...
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:561
Include the generated interface declarations.
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:304
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:128
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:144
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1293
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.