MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
27
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cstddef>
39#include <iterator>
40#include <optional>
41#include <variant>
42
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
47
48using namespace mlir;
49using namespace mlir::omp;
50
53 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
54}
55
58 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
59}
60
63 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
64}
65
66namespace {
67struct MemRefPointerLikeModel
68 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 MemRefType> {
70 Type getElementType(Type pointer) const {
71 return llvm::cast<MemRefType>(pointer).getElementType();
72 }
73};
74
75struct LLVMPointerPointerLikeModel
76 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
78 Type getElementType(Type pointer) const { return Type(); }
79};
80} // namespace
81
82/// Generate a name of a canonical loop nest of the format
83/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
84/// argument index of an operation that has multiple regions, if the operation
85/// has multiple regions.
86/// `_s<idx>` identifies the position of an operation within a region, where
87/// only operations that may potentially contain loops ("container operations"
88/// i.e. have region arguments) are counted. Again, it is omitted if there is
89/// only one such operation in a region. If there are canonical loops nested
90/// inside each other, also may also use the format `_d<num>` where <num> is the
91/// nesting depth of the loop.
92///
93/// The generated name is a best-effort to make canonical loop unique within an
94/// SSA namespace. This also means that regions with IsolatedFromAbove property
95/// do not consider any parents or siblings.
96static std::string generateLoopNestingName(StringRef prefix,
97 CanonicalLoopOp op) {
98 struct Component {
99 /// If true, this component describes a region operand of an operation (the
100 /// operand's owner) If false, this component describes an operation located
101 /// in a parent region
102 bool isRegionArgOfOp;
103 bool skip = false;
104 bool isUnique = false;
105
106 size_t idx;
107 Operation *op;
108 Region *parentRegion;
109 size_t loopDepth;
110
111 Operation *&getOwnerOp() {
112 assert(isRegionArgOfOp && "Must describe a region operand");
113 return op;
114 }
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp && "Must describe a region operand");
117 return idx;
118 }
119
120 Operation *&getContainerOp() {
121 assert(!isRegionArgOfOp && "Must describe a operation of a region");
122 return op;
123 }
124 size_t &getOpPos() {
125 assert(!isRegionArgOfOp && "Must describe a operation of a region");
126 return idx;
127 }
128 bool isLoopOp() const {
129 assert(!isRegionArgOfOp && "Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
131 }
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp && "Must describe a operation of a region");
134 return parentRegion;
135 }
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp && "Must describe a operation of a region");
138 return loopDepth;
139 }
140
141 void skipIf(bool v = true) { skip = skip || v; }
142 };
143
144 // List of ancestors, from inner to outer.
145 // Alternates between
146 // * region argument of an operation
147 // * operation within a region
148 SmallVector<Component> components;
149
150 // Gather a list of parent regions and operations, and the position within
151 // their parent
152 Operation *o = op.getOperation();
153 while (o) {
154 // Operation within a region
155 Region *r = o->getParentRegion();
156 if (!r)
157 break;
158
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
160 size_t idx = 0;
161 bool found = false;
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp = true;
164 for (Block *b : traversal) {
165 for (Operation &op : *b) {
166 if (&op == o && !found) {
167 sequentialIdx = idx;
168 found = true;
169 }
170 if (op.getNumRegions()) {
171 idx += 1;
172 if (idx > 1)
173 isOnlyContainerOp = false;
174 }
175 if (found && !isOnlyContainerOp)
176 break;
177 }
178 }
179
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp = false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
186
187 Operation *parent = r->getParentOp();
188
189 // Region argument of an operation
190 Component &regionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp = true;
192 regionArgOfOperation.isUnique = true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
195
196 // The IsolatedFromAbove trait of the parent operation implies that each
197 // individual region argument has its own separate namespace, so no
198 // ambiguity.
199 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
200 break;
201
202 // Component only needed if operation has multiple region operands. Region
203 // arguments may be optional, but we currently do not consider this.
204 if (parent->getRegions().size() > 1) {
205 auto getRegionIndex = [](Operation *o, Region *r) {
206 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
207 if (&region == r)
208 return idx;
209 }
210 llvm_unreachable("Region not child of its parent operation");
211 };
212 regionArgOfOperation.isUnique = false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
214 }
215
216 // next parent
217 o = parent;
218 }
219
220 // Determine whether a region-argument component is not needed
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223
224 // Find runs of nested loops and determine each loop's depth in the loop nest
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
227 if (c.skip)
228 continue;
229
230 // non-skipped multi-argument operands interrupt the loop nest
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
233 continue;
234 }
235
236 // Multiple loops in a region means each of them is the outermost loop of a
237 // new loop nest
238 if (!c.isUnique)
239 numSurroundingLoops = 0;
240
241 c.getLoopDepth() = numSurroundingLoops;
242
243 // Next loop is surrounded by one more loop
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
246 }
247
248 // In loop nests, skip all but the innermost loop that contains the depth
249 // number
250 bool isLoopNest = false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
253 continue;
254
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 // Innermost loop of a loop nest of at least two loops
257 isLoopNest = true;
258 } else if (isLoopNest) {
259 // Non-innermost loop of a loop nest
260 c.skipIf(c.isUnique);
261
262 // If there is no surrounding loop left, this must have been the outermost
263 // loop; leave loop-nest mode for the next iteration
264 if (c.getLoopDepth() == 0)
265 isLoopNest = false;
266 }
267 }
268
269 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
270 // we mark them as skipped only after computing loop nests)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
274
275 // Components can be skipped if they are already disambiguated by their parent
276 // (or does not have a parent)
277 bool newRegion = true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
280
281 // non-skipped components disambiguate unique children
282 if (!c.skip)
283 newRegion = true;
284
285 // ...except canonical loops that need a suffix for each nest
286 if (!c.isRegionArgOfOp && c.getContainerOp())
287 newRegion = false;
288 }
289
290 // Compile the nesting name string
291 SmallString<64> Name{prefix};
292 llvm::raw_svector_ostream NameOS(Name);
293 for (auto &c : llvm::reverse(components)) {
294 if (c.skip)
295 continue;
296
297 if (c.isRegionArgOfOp)
298 NameOS << "_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS << "_d" << c.getLoopDepth();
301 else
302 NameOS << "_s" << c.getOpPos();
303 }
304
305 return NameOS.str().str();
306}
307
308void OpenMPDialect::initialize() {
309 addOperations<
310#define GET_OP_LIST
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312 >();
313 addAttributes<
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316 >();
317 addTypes<
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 >();
321
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
326 *getContext());
327
328 // Attach the default offload module interface to the module op so that
329 // offload functionality can be accessed through it.
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
331 *getContext());
332
333 // Attach default declare target interfaces to operations which can be marked
334 // as declare target (Global Operations and Functions/Subroutines in dialects
335 // that Fortran (or other languages that lower to MLIR) translates to).
336 mlir::LLVM::GlobalOp::attachInterface<
338 *getContext());
339 mlir::LLVM::LLVMFuncOp::attachInterface<
341 *getContext());
342 mlir::func::FuncOp::attachInterface<
344}
345
346//===----------------------------------------------------------------------===//
347// Parser and printer for Allocate Clause
348//===----------------------------------------------------------------------===//
349
350/// Parse an allocate clause with allocators and a list of operands with types.
351///
352/// allocate-operand-list ::= allocate-operand |
353/// allocate-operand `,` allocate-operand-list
354/// allocate-operand ::= ssa-id-and-type `->` ssa-id-and-type
355/// ssa-id-and-type ::= ssa-id `:` type
356static ParseResult parseAllocateAndAllocator(
357 OpAsmParser &parser,
359 SmallVectorImpl<Type> &allocateTypes,
361 SmallVectorImpl<Type> &allocatorTypes) {
362
363 return parser.parseCommaSeparatedList([&]() {
365 Type type;
366 if (parser.parseOperand(operand) || parser.parseColonType(type))
367 return failure();
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
370 if (parser.parseArrow())
371 return failure();
372 if (parser.parseOperand(operand) || parser.parseColonType(type))
373 return failure();
374
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
377 return success();
378 });
379}
380
381/// Print allocate clause
383 OperandRange allocateVars,
384 TypeRange allocateTypes,
385 OperandRange allocatorVars,
386 TypeRange allocatorTypes) {
387 for (unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
389 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
390 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
391 }
392}
393
394//===----------------------------------------------------------------------===//
395// Parser and printer for a clause attribute (StringEnumAttr)
396//===----------------------------------------------------------------------===//
397
398template <typename ClauseAttr>
399static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
400 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
401 StringRef enumStr;
402 SMLoc loc = parser.getCurrentLocation();
403 if (parser.parseKeyword(&enumStr))
404 return failure();
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.getContext(), *enumValue);
407 return success();
408 }
409 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
410}
411
412template <typename ClauseAttr>
413static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
414 p << stringifyEnum(attr.getValue());
415}
416
417//===----------------------------------------------------------------------===//
418// Parser and printer for Linear Clause
419//===----------------------------------------------------------------------===//
420
421/// linear ::= `linear` `(` linear-list `)`
422/// linear-list := linear-val | linear-val linear-list
423/// linear-val := ssa-id-and-type `=` ssa-id-and-type
424/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
426/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
427static ParseResult parseLinearClause(
428 OpAsmParser &parser,
430 SmallVectorImpl<Type> &linearTypes,
432 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
433 SmallVector<Attribute> modifiers;
434 auto result = parser.parseCommaSeparatedList([&]() {
436 Type type, stepType;
438
439 std::optional<omp::LinearModifier> linearModifier;
440 if (succeeded(parser.parseOptionalKeyword("val"))) {
441 linearModifier = omp::LinearModifier::val;
442 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
443 linearModifier = omp::LinearModifier::ref;
444 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
445 linearModifier = omp::LinearModifier::uval;
446 }
447
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.parseLParen())
450 return failure();
451
452 if (parser.parseOperand(var) || parser.parseColonType(type) ||
453 parser.parseEqual() || parser.parseOperand(stepVar) ||
454 parser.parseColonType(stepType))
455 return failure();
456
457 if (hasLinearModifierParens && parser.parseRParen())
458 return failure();
459
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
465 modifiers.push_back(
466 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
467 } else {
468 modifiers.push_back(UnitAttr::get(parser.getContext()));
469 }
470 return success();
471 });
472 if (failed(result))
473 return failure();
474 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
475 return success();
476}
477
478/// Print Linear Clause
480 ValueRange linearVars, TypeRange linearTypes,
481 ValueRange linearStepVars, TypeRange stepVarTypes,
482 ArrayAttr linearModifiers) {
483 size_t linearVarsSize = linearVars.size();
484 for (unsigned i = 0; i < linearVarsSize; ++i) {
485 if (i != 0)
486 p << ", ";
487 // Print modifier keyword wrapper if present.
488 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
490 if (mod) {
491 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
492 }
493 p << linearVars[i] << " : " << linearTypes[i];
494 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
495 if (mod)
496 p << ")";
497 }
498}
499
500//===----------------------------------------------------------------------===//
501// Verifier for Linear modifier
502//===----------------------------------------------------------------------===//
503
504/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
505/// uval only on a declare simd directive."
506/// Also verifies that modifier count matches variable count.
507static LogicalResult
508verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
509 OperandRange linearVars, bool isDeclareSimd = false) {
510 if (!linearModifiers)
511 return success();
512 if (linearModifiers->size() != linearVars.size())
513 return op->emitOpError()
514 << "expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (Attribute attr : *linearModifiers) {
517 if (!attr)
518 continue;
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
520 if (!modAttr)
521 continue;
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 return op->emitOpError()
525 << "linear modifier '" << omp::stringifyLinearModifier(mod)
526 << "' may only be specified on a declare simd directive";
527 }
528 }
529 return success();
530}
531
532//===----------------------------------------------------------------------===//
533// Verifier for Nontemporal Clause
534//===----------------------------------------------------------------------===//
535
536static LogicalResult verifyNontemporalClause(Operation *op,
537 OperandRange nontemporalVars) {
538
539 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
540 DenseSet<Value> nontemporalItems;
541 for (const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->emitOpError() << "nontemporal variable used more than once";
544
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Parser, verifier and printer for Aligned Clause
550//===----------------------------------------------------------------------===//
551static LogicalResult verifyAlignedClause(Operation *op,
552 std::optional<ArrayAttr> alignments,
553 OperandRange alignedVars) {
554 // Check if number of alignment values equals to number of aligned variables
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
557 return op->emitOpError()
558 << "expected as many alignment values as aligned variables";
559 } else {
560 if (alignments)
561 return op->emitOpError() << "unexpected alignment values attribute";
562 return success();
563 }
564
565 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
566 DenseSet<Value> alignedItems;
567 for (auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->emitOpError() << "aligned variable used more than once";
570
571 if (!alignments)
572 return success();
573
574 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
575 for (unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->emitOpError() << "alignment should be greater than 0";
579 } else {
580 return op->emitOpError() << "expected integer alignment";
581 }
582 }
583
584 return success();
585}
586
587/// aligned ::= `aligned` `(` aligned-list `)`
588/// aligned-list := aligned-val | aligned-val aligned-list
589/// aligned-val := ssa-id-and-type `->` alignment
590static ParseResult
593 SmallVectorImpl<Type> &alignedTypes,
594 ArrayAttr &alignmentsAttr) {
595 SmallVector<Attribute> alignmentVec;
596 if (failed(parser.parseCommaSeparatedList([&]() {
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
601 return failure();
602 }
603 return success();
604 })))
605 return failure();
606 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
608 return success();
609}
610
611/// Print Aligned Clause
613 ValueRange alignedVars, TypeRange alignedTypes,
614 std::optional<ArrayAttr> alignments) {
615 for (unsigned i = 0; i < alignedVars.size(); ++i) {
616 if (i != 0)
617 p << ", ";
618 p << alignedVars[i] << " : " << alignedVars[i].getType();
619 p << " -> " << (*alignments)[i];
620 }
621}
622
623//===----------------------------------------------------------------------===//
624// Parser, printer and verifier for Schedule Clause
625//===----------------------------------------------------------------------===//
626
627static ParseResult
629 SmallVectorImpl<SmallString<12>> &modifiers) {
630 if (modifiers.size() > 2)
631 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
632 for (const auto &mod : modifiers) {
633 // Translate the string. If it has no value, then it was not a valid
634 // modifier!
635 auto symbol = symbolizeScheduleModifier(mod);
636 if (!symbol)
637 return parser.emitError(parser.getNameLoc())
638 << " unknown modifier type: " << mod;
639 }
640
641 // If we have one modifier that is "simd", then stick a "none" modifier in
642 // index 0.
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
648 } else if (modifiers.size() == 2) {
649 // If there are two modifiers:
650 // First modifier should not be simd, second one should be simd
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 return parser.emitError(parser.getNameLoc())
654 << " incorrect modifier order";
655 }
656 return success();
657}
658
659/// schedule ::= `schedule` `(` sched-list `)`
660/// sched-list ::= sched-val | sched-val sched-list |
661/// sched-val `,` sched-modifier
662/// sched-val ::= sched-with-chunk | sched-wo-chunk
663/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
664/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
665/// sched-wo-chunk ::= `auto` | `runtime`
666/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
667/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
668static ParseResult
669parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
672 Type &chunkType) {
673 StringRef keyword;
674 if (parser.parseKeyword(&keyword))
675 return failure();
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
678 if (!schedule)
679 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
680
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
682 switch (*schedule) {
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
686 if (succeeded(parser.parseOptionalEqual())) {
687 chunkSize = OpAsmParser::UnresolvedOperand{};
688 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
689 return failure();
690 } else {
691 chunkSize = std::nullopt;
692 }
693 break;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
698 }
699
700 // If there is a comma, we have one or more modifiers.
702 while (succeeded(parser.parseOptionalComma())) {
703 StringRef mod;
704 if (parser.parseKeyword(&mod))
705 return failure();
706 modifiers.push_back(mod);
707 }
708
709 if (verifyScheduleModifiers(parser, modifiers))
710 return failure();
711
712 if (!modifiers.empty()) {
713 SMLoc loc = parser.getCurrentLocation();
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
717 } else {
718 return parser.emitError(loc, "invalid schedule modifier");
719 }
720 // Only SIMD attribute is allowed here!
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
723 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
724 }
725 }
726
727 return success();
728}
729
730/// Print schedule clause
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd, Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 if (scheduleChunk)
738 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
739 if (scheduleMod)
740 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
741 if (scheduleSimd)
742 p << ", simd";
743}
744
745//===----------------------------------------------------------------------===//
746// Parser and printer for Order Clause
747//===----------------------------------------------------------------------===//
748
749// order ::= `order` `(` [order-modifier ':'] concurrent `)`
750// order-modifier ::= reproducible | unconstrained
751static ParseResult parseOrderClause(OpAsmParser &parser,
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
754 StringRef enumStr;
755 SMLoc loc = parser.getCurrentLocation();
756 if (parser.parseKeyword(&enumStr))
757 return failure();
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
761 if (parser.parseOptionalColon())
762 return failure();
763 loc = parser.getCurrentLocation();
764 if (parser.parseKeyword(&enumStr))
765 return failure();
766 }
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
770 return success();
771 }
772 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
773}
774
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
778 if (orderMod)
779 p << stringifyOrderModifier(orderMod.getValue()) << ":";
780 if (order)
781 p << stringifyClauseOrderKind(order.getValue());
782}
783
784template <typename ClauseTypeAttr, typename ClauseType>
785static ParseResult
786parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 Type &operandType,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
791 StringRef enumStr;
792 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
795 if (parser.parseComma())
796 return failure();
797 } else {
798 return parser.emitError(parser.getCurrentLocation())
799 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
800 ;
801 }
802 }
803
805 if (succeeded(parser.parseOperand(var))) {
806 operand = var;
807 } else {
808 return parser.emitError(parser.getCurrentLocation())
809 << "expected " << clauseName << " operand";
810 }
811
812 if (operand.has_value()) {
813 if (parser.parseColonType(operandType))
814 return failure();
815 }
816
817 return success();
818}
819
820template <typename ClauseTypeAttr, typename ClauseType>
821static void
823 ClauseTypeAttr prescriptiveness, Value operand,
824 mlir::Type operandType,
825 StringRef (*stringifyClauseType)(ClauseType)) {
826
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
829
830 if (operand)
831 p << operand << ": " << operandType;
832}
833
834//===----------------------------------------------------------------------===//
835// Parser and printer for grainsize Clause
836//===----------------------------------------------------------------------===//
837
838// grainsize ::= `grainsize` `(` [strict `,`] grain-size `:` type `)`
839static ParseResult
840parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType, "grainsize");
846}
847
849 ClauseGrainsizeTypeAttr grainsizeMod,
850 Value grainsize, mlir::Type grainsizeType) {
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
854}
855
856//===----------------------------------------------------------------------===//
857// Parser and printer for num_tasks Clause
858//===----------------------------------------------------------------------===//
859
860// num_tasks ::= `num_tasks` `(` [strict `,`] num-tasks `:` type `)`
861static ParseResult
862parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
867 "num_tasks");
868}
869
871 ClauseNumTasksTypeAttr numTasksMod,
872 Value numTasks, mlir::Type numTasksType) {
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
875}
876
877//===----------------------------------------------------------------------===//
878// Parsers for operations including clauses that define entry block arguments.
879//===----------------------------------------------------------------------===//
880
881namespace {
882struct MapParseArgs {
883 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
884 SmallVectorImpl<Type> &types;
885 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
886 SmallVectorImpl<Type> &types)
887 : vars(vars), types(types) {}
888};
889struct PrivateParseArgs {
890 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
891 llvm::SmallVectorImpl<Type> &types;
892 ArrayAttr &syms;
893 UnitAttr &needsBarrier;
894 DenseI64ArrayAttr *mapIndices;
895 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
896 SmallVectorImpl<Type> &types, ArrayAttr &syms,
897 UnitAttr &needsBarrier,
898 DenseI64ArrayAttr *mapIndices = nullptr)
899 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
900 mapIndices(mapIndices) {}
901};
902
903struct ReductionParseArgs {
904 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
905 SmallVectorImpl<Type> &types;
906 DenseBoolArrayAttr &byref;
907 ArrayAttr &syms;
908 ReductionModifierAttr *modifier;
909 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
910 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
911 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
912 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
913};
914
915struct AllRegionParseArgs {
916 std::optional<MapParseArgs> hasDeviceAddrArgs;
917 std::optional<MapParseArgs> hostEvalArgs;
918 std::optional<ReductionParseArgs> inReductionArgs;
919 std::optional<MapParseArgs> mapArgs;
920 std::optional<PrivateParseArgs> privateArgs;
921 std::optional<ReductionParseArgs> reductionArgs;
922 std::optional<ReductionParseArgs> taskReductionArgs;
923 std::optional<MapParseArgs> useDeviceAddrArgs;
924 std::optional<MapParseArgs> useDevicePtrArgs;
925};
926} // namespace
927
928static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
929 return "private_barrier";
930}
931
932static ParseResult parseClauseWithRegionArgs(
933 OpAsmParser &parser,
937 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
938 DenseBoolArrayAttr *byref = nullptr,
939 ReductionModifierAttr *modifier = nullptr,
940 UnitAttr *needsBarrier = nullptr) {
942 SmallVector<int64_t> mapIndicesVec;
943 SmallVector<bool> isByRefVec;
944 unsigned regionArgOffset = regionPrivateArgs.size();
945
946 if (parser.parseLParen())
947 return failure();
948
949 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
950 StringRef enumStr;
951 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
952 parser.parseComma())
953 return failure();
954 std::optional<ReductionModifier> enumValue =
955 symbolizeReductionModifier(enumStr);
956 if (!enumValue.has_value())
957 return failure();
958 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
959 if (!*modifier)
960 return failure();
961 }
962
963 if (parser.parseCommaSeparatedList([&]() {
964 if (byref)
965 isByRefVec.push_back(
966 parser.parseOptionalKeyword("byref").succeeded());
967
968 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
969 return failure();
970
971 if (parser.parseOperand(operands.emplace_back()) ||
972 parser.parseArrow() ||
973 parser.parseArgument(regionPrivateArgs.emplace_back()))
974 return failure();
975
976 if (mapIndices) {
977 if (parser.parseOptionalLSquare().succeeded()) {
978 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
979 parser.parseInteger(mapIndicesVec.emplace_back()) ||
980 parser.parseRSquare())
981 return failure();
982 } else {
983 mapIndicesVec.push_back(-1);
984 }
985 }
986
987 return success();
988 }))
989 return failure();
990
991 if (parser.parseColon())
992 return failure();
993
994 if (parser.parseCommaSeparatedList([&]() {
995 if (parser.parseType(types.emplace_back()))
996 return failure();
997
998 return success();
999 }))
1000 return failure();
1001
1002 if (operands.size() != types.size())
1003 return failure();
1004
1005 if (parser.parseRParen())
1006 return failure();
1007
1008 if (needsBarrier) {
1010 .succeeded())
1011 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
1012 }
1013
1014 auto *argsBegin = regionPrivateArgs.begin();
1015 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
1016 argsBegin + regionArgOffset + types.size());
1017 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1018 prv.type = type;
1019 }
1020
1021 if (symbols) {
1022 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
1023 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
1024 }
1025
1026 if (!mapIndicesVec.empty())
1027 *mapIndices =
1028 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
1029
1030 if (byref)
1031 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
1032
1033 return success();
1034}
1035
1036static ParseResult parseBlockArgClause(
1037 OpAsmParser &parser,
1039 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1040 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1041 if (!mapArgs)
1042 return failure();
1043
1044 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
1045 entryBlockArgs)))
1046 return failure();
1047 }
1048 return success();
1049}
1050
1051static ParseResult parseBlockArgClause(
1052 OpAsmParser &parser,
1054 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1055 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1056 if (!privateArgs)
1057 return failure();
1058
1059 if (failed(parseClauseWithRegionArgs(
1060 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1061 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1062 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
1063 return failure();
1064 }
1065 return success();
1066}
1067
1068static ParseResult parseBlockArgClause(
1069 OpAsmParser &parser,
1071 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1072 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1073 if (!reductionArgs)
1074 return failure();
1075 if (failed(parseClauseWithRegionArgs(
1076 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1077 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1078 reductionArgs->modifier)))
1079 return failure();
1080 }
1081 return success();
1082}
1083
1084static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1085 AllRegionParseArgs args) {
1087
1088 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1089 args.hasDeviceAddrArgs)))
1090 return parser.emitError(parser.getCurrentLocation())
1091 << "invalid `has_device_addr` format";
1092
1093 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1094 args.hostEvalArgs)))
1095 return parser.emitError(parser.getCurrentLocation())
1096 << "invalid `host_eval` format";
1097
1098 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1099 args.inReductionArgs)))
1100 return parser.emitError(parser.getCurrentLocation())
1101 << "invalid `in_reduction` format";
1102
1103 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1104 args.mapArgs)))
1105 return parser.emitError(parser.getCurrentLocation())
1106 << "invalid `map_entries` format";
1107
1108 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1109 args.privateArgs)))
1110 return parser.emitError(parser.getCurrentLocation())
1111 << "invalid `private` format";
1112
1113 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1114 args.reductionArgs)))
1115 return parser.emitError(parser.getCurrentLocation())
1116 << "invalid `reduction` format";
1117
1118 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1119 args.taskReductionArgs)))
1120 return parser.emitError(parser.getCurrentLocation())
1121 << "invalid `task_reduction` format";
1122
1123 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1124 args.useDeviceAddrArgs)))
1125 return parser.emitError(parser.getCurrentLocation())
1126 << "invalid `use_device_addr` format";
1127
1128 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1129 args.useDevicePtrArgs)))
1130 return parser.emitError(parser.getCurrentLocation())
1131 << "invalid `use_device_addr` format";
1132
1133 return parser.parseRegion(region, entryBlockArgs);
1134}
1135
1136// These parseXyz functions correspond to the custom<Xyz> definitions
1137// in the .td file(s).
1138static ParseResult parseTargetOpRegion(
1139 OpAsmParser &parser, Region &region,
1141 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1143 SmallVectorImpl<Type> &hostEvalTypes,
1145 SmallVectorImpl<Type> &inReductionTypes,
1146 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1148 SmallVectorImpl<Type> &mapTypes,
1150 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1151 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1152 AllRegionParseArgs args;
1153 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1154 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1155 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1156 inReductionByref, inReductionSyms);
1157 args.mapArgs.emplace(mapVars, mapTypes);
1158 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1159 privateNeedsBarrier, &privateMaps);
1160 return parseBlockArgRegion(parser, region, args);
1161}
1162
1164 OpAsmParser &parser, Region &region,
1166 SmallVectorImpl<Type> &inReductionTypes,
1167 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1169 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1170 UnitAttr &privateNeedsBarrier) {
1171 AllRegionParseArgs args;
1172 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1173 inReductionByref, inReductionSyms);
1174 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1175 privateNeedsBarrier);
1176 return parseBlockArgRegion(parser, region, args);
1177}
1178
1180 OpAsmParser &parser, Region &region,
1182 SmallVectorImpl<Type> &inReductionTypes,
1183 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1185 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1186 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1188 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1189 ArrayAttr &reductionSyms) {
1190 AllRegionParseArgs args;
1191 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1192 inReductionByref, inReductionSyms);
1193 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1194 privateNeedsBarrier);
1195 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1196 reductionSyms, &reductionMod);
1197 return parseBlockArgRegion(parser, region, args);
1198}
1199
1200static ParseResult parsePrivateRegion(
1201 OpAsmParser &parser, Region &region,
1203 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1204 UnitAttr &privateNeedsBarrier) {
1205 AllRegionParseArgs args;
1206 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1207 privateNeedsBarrier);
1208 return parseBlockArgRegion(parser, region, args);
1209}
1210
1212 OpAsmParser &parser, Region &region,
1214 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1215 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1217 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1218 ArrayAttr &reductionSyms) {
1219 AllRegionParseArgs args;
1220 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1221 privateNeedsBarrier);
1222 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1223 reductionSyms, &reductionMod);
1224 return parseBlockArgRegion(parser, region, args);
1225}
1226
1227static ParseResult parseTaskReductionRegion(
1228 OpAsmParser &parser, Region &region,
1230 SmallVectorImpl<Type> &taskReductionTypes,
1231 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1232 AllRegionParseArgs args;
1233 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1234 taskReductionByref, taskReductionSyms);
1235 return parseBlockArgRegion(parser, region, args);
1236}
1237
1239 OpAsmParser &parser, Region &region,
1241 SmallVectorImpl<Type> &useDeviceAddrTypes,
1243 SmallVectorImpl<Type> &useDevicePtrTypes) {
1244 AllRegionParseArgs args;
1245 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1246 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1247 return parseBlockArgRegion(parser, region, args);
1248}
1249
1250//===----------------------------------------------------------------------===//
1251// Printers for operations including clauses that define entry block arguments.
1252//===----------------------------------------------------------------------===//
1253
1254namespace {
1255struct MapPrintArgs {
1256 ValueRange vars;
1257 TypeRange types;
1258 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1259};
1260struct PrivatePrintArgs {
1261 ValueRange vars;
1262 TypeRange types;
1263 ArrayAttr syms;
1264 UnitAttr needsBarrier;
1265 DenseI64ArrayAttr mapIndices;
1266 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1267 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1268 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1269 mapIndices(mapIndices) {}
1270};
1271struct ReductionPrintArgs {
1272 ValueRange vars;
1273 TypeRange types;
1274 DenseBoolArrayAttr byref;
1275 ArrayAttr syms;
1276 ReductionModifierAttr modifier;
1277 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1278 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1279 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1280};
1281struct AllRegionPrintArgs {
1282 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1283 std::optional<MapPrintArgs> hostEvalArgs;
1284 std::optional<ReductionPrintArgs> inReductionArgs;
1285 std::optional<MapPrintArgs> mapArgs;
1286 std::optional<PrivatePrintArgs> privateArgs;
1287 std::optional<ReductionPrintArgs> reductionArgs;
1288 std::optional<ReductionPrintArgs> taskReductionArgs;
1289 std::optional<MapPrintArgs> useDeviceAddrArgs;
1290 std::optional<MapPrintArgs> useDevicePtrArgs;
1291};
1292} // namespace
1293
1295 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1296 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1297 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1298 DenseBoolArrayAttr byref = nullptr,
1299 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1300 if (argsSubrange.empty())
1301 return;
1302
1303 p << clauseName << "(";
1304
1305 if (modifier)
1306 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1307
1308 if (!symbols) {
1309 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1310 symbols = ArrayAttr::get(ctx, values);
1311 }
1312
1313 if (!mapIndices) {
1314 llvm::SmallVector<int64_t> values(operands.size(), -1);
1315 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1316 }
1317
1318 if (!byref) {
1319 mlir::SmallVector<bool> values(operands.size(), false);
1320 byref = DenseBoolArrayAttr::get(ctx, values);
1321 }
1322
1323 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1324 mapIndices.asArrayRef(),
1325 byref.asArrayRef()),
1326 p, [&p](auto t) {
1327 auto [op, arg, sym, map, isByRef] = t;
1328 if (isByRef)
1329 p << "byref ";
1330 if (sym)
1331 p << sym << " ";
1332
1333 p << op << " -> " << arg;
1334
1335 if (map != -1)
1336 p << " [map_idx=" << map << "]";
1337 });
1338 p << " : ";
1339 llvm::interleaveComma(types, p);
1340 p << ") ";
1341
1342 if (needsBarrier)
1343 p << getPrivateNeedsBarrierSpelling() << " ";
1344}
1345
1347 StringRef clauseName, ValueRange argsSubrange,
1348 std::optional<MapPrintArgs> mapArgs) {
1349 if (mapArgs)
1350 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1351 mapArgs->types);
1352}
1353
1355 StringRef clauseName, ValueRange argsSubrange,
1356 std::optional<PrivatePrintArgs> privateArgs) {
1357 if (privateArgs)
1359 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1360 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1361 /*modifier=*/nullptr, privateArgs->needsBarrier);
1362}
1363
1364static void
1365printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1366 ValueRange argsSubrange,
1367 std::optional<ReductionPrintArgs> reductionArgs) {
1368 if (reductionArgs)
1369 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1370 reductionArgs->vars, reductionArgs->types,
1371 reductionArgs->syms, /*mapIndices=*/nullptr,
1372 reductionArgs->byref, reductionArgs->modifier);
1373}
1374
1376 const AllRegionPrintArgs &args) {
1377 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1378 MLIRContext *ctx = op->getContext();
1379
1380 printBlockArgClause(p, ctx, "has_device_addr",
1381 iface.getHasDeviceAddrBlockArgs(),
1382 args.hasDeviceAddrArgs);
1383 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1384 args.hostEvalArgs);
1385 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1386 args.inReductionArgs);
1387 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1388 args.mapArgs);
1389 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1390 args.privateArgs);
1391 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1392 args.reductionArgs);
1393 printBlockArgClause(p, ctx, "task_reduction",
1394 iface.getTaskReductionBlockArgs(),
1395 args.taskReductionArgs);
1396 printBlockArgClause(p, ctx, "use_device_addr",
1397 iface.getUseDeviceAddrBlockArgs(),
1398 args.useDeviceAddrArgs);
1399 printBlockArgClause(p, ctx, "use_device_ptr",
1400 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1401
1402 p.printRegion(region, /*printEntryBlockArgs=*/false);
1403}
1404
1405// These printXyz functions correspond to the custom<Xyz> definitions
1406// in the .td file(s).
1408 OpAsmPrinter &p, Operation *op, Region &region,
1409 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1410 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1411 ValueRange inReductionVars, TypeRange inReductionTypes,
1412 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1413 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1414 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1415 DenseI64ArrayAttr privateMaps) {
1416 AllRegionPrintArgs args;
1417 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1418 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1419 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1420 inReductionByref, inReductionSyms);
1421 args.mapArgs.emplace(mapVars, mapTypes);
1422 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1423 privateNeedsBarrier, privateMaps);
1424 printBlockArgRegion(p, op, region, args);
1425}
1426
1428 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1429 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1430 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1431 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1432 AllRegionPrintArgs args;
1433 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1434 inReductionByref, inReductionSyms);
1435 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1436 privateNeedsBarrier,
1437 /*mapIndices=*/nullptr);
1438 printBlockArgRegion(p, op, region, args);
1439}
1440
1442 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1443 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1444 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1445 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1446 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1447 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1448 ArrayAttr reductionSyms) {
1449 AllRegionPrintArgs args;
1450 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1451 inReductionByref, inReductionSyms);
1452 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1453 privateNeedsBarrier,
1454 /*mapIndices=*/nullptr);
1455 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1456 reductionSyms, reductionMod);
1457 printBlockArgRegion(p, op, region, args);
1458}
1459
1461 ValueRange privateVars, TypeRange privateTypes,
1462 ArrayAttr privateSyms,
1463 UnitAttr privateNeedsBarrier) {
1464 AllRegionPrintArgs args;
1465 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1466 privateNeedsBarrier,
1467 /*mapIndices=*/nullptr);
1468 printBlockArgRegion(p, op, region, args);
1469}
1470
1472 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1473 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1474 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1475 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1476 ArrayAttr reductionSyms) {
1477 AllRegionPrintArgs args;
1478 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1479 privateNeedsBarrier,
1480 /*mapIndices=*/nullptr);
1481 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1482 reductionSyms, reductionMod);
1483 printBlockArgRegion(p, op, region, args);
1484}
1485
1487 Region &region,
1488 ValueRange taskReductionVars,
1489 TypeRange taskReductionTypes,
1490 DenseBoolArrayAttr taskReductionByref,
1491 ArrayAttr taskReductionSyms) {
1492 AllRegionPrintArgs args;
1493 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1494 taskReductionByref, taskReductionSyms);
1495 printBlockArgRegion(p, op, region, args);
1496}
1497
1499 Region &region,
1500 ValueRange useDeviceAddrVars,
1501 TypeRange useDeviceAddrTypes,
1502 ValueRange useDevicePtrVars,
1503 TypeRange useDevicePtrTypes) {
1504 AllRegionPrintArgs args;
1505 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1506 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1507 printBlockArgRegion(p, op, region, args);
1508}
1509
1510template <typename ParsePrefixFn>
1511static ParseResult parseSplitIteratedList(
1512 OpAsmParser &parser,
1514 SmallVectorImpl<Type> &iteratedTypes,
1516 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1517
1518 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1519 if (failed(parsePrefix()))
1520 return failure();
1521
1523 Type ty;
1524 if (parser.parseOperand(v) || parser.parseColonType(ty))
1525 return failure();
1526
1527 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1528 iteratedVars.push_back(v);
1529 iteratedTypes.push_back(ty);
1530 } else {
1531 plainVars.push_back(v);
1532 plainTypes.push_back(ty);
1533 }
1534 return success();
1535 });
1536}
1537
1538template <typename PrintPrefixFn>
1540 TypeRange iteratedTypes,
1541 ValueRange plainVars, TypeRange plainTypes,
1542 PrintPrefixFn &&printPrefixForPlain,
1543 PrintPrefixFn &&printPrefixForIterated) {
1544
1545 bool first = true;
1546 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1547 if (!first)
1548 p << ", ";
1549 printPrefix(v, t);
1550 p << v << " : " << t;
1551 first = false;
1552 };
1553
1554 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1555 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1556 for (unsigned i = 0; i < plainVars.size(); ++i)
1557 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1558}
1559
1560/// Verifies Reduction Clause
1561static LogicalResult
1562verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1563 OperandRange reductionVars,
1564 std::optional<ArrayRef<bool>> reductionByref) {
1565 if (!reductionVars.empty()) {
1566 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1567 return op->emitOpError()
1568 << "expected as many reduction symbol references "
1569 "as reduction variables";
1570 if (reductionByref && reductionByref->size() != reductionVars.size())
1571 return op->emitError() << "expected as many reduction variable by "
1572 "reference attributes as reduction variables";
1573 } else {
1574 if (reductionSyms)
1575 return op->emitOpError() << "unexpected reduction symbol references";
1576 return success();
1577 }
1578
1579 // TODO: The followings should be done in
1580 // SymbolUserOpInterface::verifySymbolUses.
1581 DenseSet<Value> accumulators;
1582 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1583 Value accum = std::get<0>(args);
1584
1585 if (!accumulators.insert(accum).second)
1586 return op->emitOpError() << "accumulator variable used more than once";
1587
1588 Type varType = accum.getType();
1589 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1590 auto decl =
1592 if (!decl)
1593 return op->emitOpError() << "expected symbol reference " << symbolRef
1594 << " to point to a reduction declaration";
1595
1596 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1597 return op->emitOpError()
1598 << "expected accumulator (" << varType
1599 << ") to be the same type as reduction declaration ("
1600 << decl.getAccumulatorType() << ")";
1601 }
1602
1603 return success();
1604}
1605
1606//===----------------------------------------------------------------------===//
1607// Parser, printer and verifier for Copyprivate
1608//===----------------------------------------------------------------------===//
1609
1610/// copyprivate-entry-list ::= copyprivate-entry
1611/// | copyprivate-entry-list `,` copyprivate-entry
1612/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1613static ParseResult parseCopyprivate(
1614 OpAsmParser &parser,
1616 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1618 if (failed(parser.parseCommaSeparatedList([&]() {
1619 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1620 parser.parseArrow() ||
1621 parser.parseAttribute(symsVec.emplace_back()) ||
1622 parser.parseColonType(copyprivateTypes.emplace_back()))
1623 return failure();
1624 return success();
1625 })))
1626 return failure();
1627 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1628 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1629 return success();
1630}
1631
1632/// Print Copyprivate clause
1634 OperandRange copyprivateVars,
1635 TypeRange copyprivateTypes,
1636 std::optional<ArrayAttr> copyprivateSyms) {
1637 if (!copyprivateSyms.has_value())
1638 return;
1639 llvm::interleaveComma(
1640 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1641 [&](const auto &args) {
1642 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1643 << std::get<2>(args);
1644 });
1645}
1646
1647/// Verifies CopyPrivate Clause
1648static LogicalResult
1650 std::optional<ArrayAttr> copyprivateSyms) {
1651 size_t copyprivateSymsSize =
1652 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1653 if (copyprivateSymsSize != copyprivateVars.size())
1654 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1655 << copyprivateVars.size()
1656 << ") and functions (= " << copyprivateSymsSize
1657 << "), both must be equal";
1658 if (!copyprivateSyms.has_value())
1659 return success();
1660
1661 for (auto copyprivateVarAndSym :
1662 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1663 auto symbolRef =
1664 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1665 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1666 funcOp;
1667 if (mlir::func::FuncOp mlirFuncOp =
1669 symbolRef))
1670 funcOp = mlirFuncOp;
1671 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1673 op, symbolRef))
1674 funcOp = llvmFuncOp;
1675
1676 auto getNumArguments = [&] {
1677 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1678 };
1679
1680 auto getArgumentType = [&](unsigned i) {
1681 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1682 *funcOp);
1683 };
1684
1685 if (!funcOp)
1686 return op->emitOpError() << "expected symbol reference " << symbolRef
1687 << " to point to a copy function";
1688
1689 if (getNumArguments() != 2)
1690 return op->emitOpError()
1691 << "expected copy function " << symbolRef << " to have 2 operands";
1692
1693 Type argTy = getArgumentType(0);
1694 if (argTy != getArgumentType(1))
1695 return op->emitOpError() << "expected copy function " << symbolRef
1696 << " arguments to have the same type";
1697
1698 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1699 if (argTy != varType)
1700 return op->emitOpError()
1701 << "expected copy function arguments' type (" << argTy
1702 << ") to be the same as copyprivate variable's type (" << varType
1703 << ")";
1704 }
1705
1706 return success();
1707}
1708
1709//===----------------------------------------------------------------------===//
1710// Parser, printer and verifier for DependVarList
1711//===----------------------------------------------------------------------===//
1712
1713/// depend-entry-list ::= depend-entry
1714/// | depend-entry-list `,` depend-entry
1715/// depend-entry ::= depend-kind `->` ssa-id `:` type
1716/// | depend-kind `->` ssa-id `:` iterated-type
1717static ParseResult parseDependVarList(
1718 OpAsmParser &parser,
1720 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
1722 SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
1725 if (failed(parser.parseCommaSeparatedList([&]() {
1726 StringRef keyword;
1727 OpAsmParser::UnresolvedOperand operand;
1728 Type ty;
1729 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1730 parser.parseOperand(operand) || parser.parseColonType(ty))
1731 return failure();
1732 std::optional<ClauseTaskDepend> keywordDepend =
1733 symbolizeClauseTaskDepend(keyword);
1734 if (!keywordDepend)
1735 return failure();
1736 auto kindAttr =
1737 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1738 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1739 iteratedVars.push_back(operand);
1740 iteratedTypes.push_back(ty);
1741 iterKindsVec.push_back(kindAttr);
1742 } else {
1743 dependVars.push_back(operand);
1744 dependTypes.push_back(ty);
1745 kindsVec.push_back(kindAttr);
1746 }
1747 return success();
1748 })))
1749 return failure();
1750 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1751 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1752 SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
1753 iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
1754 return success();
1755}
1756
1757/// Print Depend clause
1759 OperandRange dependVars, TypeRange dependTypes,
1760 std::optional<ArrayAttr> dependKinds,
1761 OperandRange iteratedVars,
1762 TypeRange iteratedTypes,
1763 std::optional<ArrayAttr> iteratedKinds) {
1764 bool first = true;
1765 auto printEntries = [&](OperandRange vars, TypeRange types,
1766 std::optional<ArrayAttr> kinds) {
1767 for (unsigned i = 0, e = vars.size(); i < e; ++i) {
1768 if (!first)
1769 p << ", ";
1770 p << stringifyClauseTaskDepend(
1771 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1772 .getValue())
1773 << " -> " << vars[i] << " : " << types[i];
1774 first = false;
1775 }
1776 };
1777 printEntries(dependVars, dependTypes, dependKinds);
1778 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1779}
1780
1781/// Verifies Depend clause
1782static LogicalResult verifyDependVarList(Operation *op,
1783 std::optional<ArrayAttr> dependKinds,
1784 OperandRange dependVars,
1785 std::optional<ArrayAttr> iteratedKinds,
1786 OperandRange iteratedVars) {
1787 if (!dependVars.empty()) {
1788 if (!dependKinds || dependKinds->size() != dependVars.size())
1789 return op->emitOpError() << "expected as many depend values"
1790 " as depend variables";
1791 } else {
1792 if (dependKinds && !dependKinds->empty())
1793 return op->emitOpError() << "unexpected depend values";
1794 }
1795
1796 if (!iteratedVars.empty()) {
1797 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1798 return op->emitOpError() << "expected as many depend iterated values"
1799 " as depend iterated variables";
1800 } else {
1801 if (iteratedKinds && !iteratedKinds->empty())
1802 return op->emitOpError() << "unexpected depend iterated values";
1803 }
1804
1805 return success();
1806}
1807
1808//===----------------------------------------------------------------------===//
1809// Parser, printer and verifier for Synchronization Hint (2.17.12)
1810//===----------------------------------------------------------------------===//
1811
1812/// Parses a Synchronization Hint clause. The value of hint is an integer
1813/// which is a combination of different hints from `omp_sync_hint_t`.
1814///
1815/// hint-clause = `hint` `(` hint-value `)`
1816static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1817 IntegerAttr &hintAttr) {
1818 StringRef hintKeyword;
1819 int64_t hint = 0;
1820 if (succeeded(parser.parseOptionalKeyword("none"))) {
1821 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1822 return success();
1823 }
1824 auto parseKeyword = [&]() -> ParseResult {
1825 if (failed(parser.parseKeyword(&hintKeyword)))
1826 return failure();
1827 if (hintKeyword == "uncontended")
1828 hint |= 1;
1829 else if (hintKeyword == "contended")
1830 hint |= 2;
1831 else if (hintKeyword == "nonspeculative")
1832 hint |= 4;
1833 else if (hintKeyword == "speculative")
1834 hint |= 8;
1835 else
1836 return parser.emitError(parser.getCurrentLocation())
1837 << hintKeyword << " is not a valid hint";
1838 return success();
1839 };
1840 if (parser.parseCommaSeparatedList(parseKeyword))
1841 return failure();
1842 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
1843 return success();
1844}
1845
1846/// Prints a Synchronization Hint clause
1848 IntegerAttr hintAttr) {
1849 int64_t hint = hintAttr.getInt();
1850
1851 if (hint == 0) {
1852 p << "none";
1853 return;
1854 }
1855
1856 // Helper function to get n-th bit from the right end of `value`
1857 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1858
1859 bool uncontended = bitn(hint, 0);
1860 bool contended = bitn(hint, 1);
1861 bool nonspeculative = bitn(hint, 2);
1862 bool speculative = bitn(hint, 3);
1863
1865 if (uncontended)
1866 hints.push_back("uncontended");
1867 if (contended)
1868 hints.push_back("contended");
1869 if (nonspeculative)
1870 hints.push_back("nonspeculative");
1871 if (speculative)
1872 hints.push_back("speculative");
1873
1874 llvm::interleaveComma(hints, p);
1875}
1876
1877/// Verifies a synchronization hint clause
1878static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
1879
1880 // Helper function to get n-th bit from the right end of `value`
1881 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
1882
1883 bool uncontended = bitn(hint, 0);
1884 bool contended = bitn(hint, 1);
1885 bool nonspeculative = bitn(hint, 2);
1886 bool speculative = bitn(hint, 3);
1887
1888 if (uncontended && contended)
1889 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
1890 "omp_sync_hint_contended cannot be combined";
1891 if (nonspeculative && speculative)
1892 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
1893 "omp_sync_hint_speculative cannot be combined.";
1894 return success();
1895}
1896
1897//===----------------------------------------------------------------------===//
1898// Parser, printer and verifier for Target
1899//===----------------------------------------------------------------------===//
1900
1901// Helper function to get bitwise AND of `value` and 'flag' then return it as a
1902// boolean
1903static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
1904 return (value & flag) == flag;
1905}
1906
1907/// Parses a map_entries map type from a string format back into its numeric
1908/// value.
1909///
1910/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
1911/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
1912static ParseResult parseMapClause(OpAsmParser &parser,
1913 ClauseMapFlagsAttr &mapType) {
1914 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
1915 // This simply verifies the correct keyword is read in, the
1916 // keyword itself is stored inside of the operation
1917 auto parseTypeAndMod = [&]() -> ParseResult {
1918 StringRef mapTypeMod;
1919 if (parser.parseKeyword(&mapTypeMod))
1920 return failure();
1921
1922 if (mapTypeMod == "always")
1923 mapTypeBits |= ClauseMapFlags::always;
1924
1925 if (mapTypeMod == "implicit")
1926 mapTypeBits |= ClauseMapFlags::implicit;
1927
1928 if (mapTypeMod == "ompx_hold")
1929 mapTypeBits |= ClauseMapFlags::ompx_hold;
1930
1931 if (mapTypeMod == "close")
1932 mapTypeBits |= ClauseMapFlags::close;
1933
1934 if (mapTypeMod == "present")
1935 mapTypeBits |= ClauseMapFlags::present;
1936
1937 if (mapTypeMod == "to")
1938 mapTypeBits |= ClauseMapFlags::to;
1939
1940 if (mapTypeMod == "from")
1941 mapTypeBits |= ClauseMapFlags::from;
1942
1943 if (mapTypeMod == "tofrom")
1944 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
1945
1946 if (mapTypeMod == "delete")
1947 mapTypeBits |= ClauseMapFlags::del;
1948
1949 if (mapTypeMod == "storage")
1950 mapTypeBits |= ClauseMapFlags::storage;
1951
1952 if (mapTypeMod == "return_param")
1953 mapTypeBits |= ClauseMapFlags::return_param;
1954
1955 if (mapTypeMod == "private")
1956 mapTypeBits |= ClauseMapFlags::priv;
1957
1958 if (mapTypeMod == "literal")
1959 mapTypeBits |= ClauseMapFlags::literal;
1960
1961 if (mapTypeMod == "attach")
1962 mapTypeBits |= ClauseMapFlags::attach;
1963
1964 if (mapTypeMod == "attach_always")
1965 mapTypeBits |= ClauseMapFlags::attach_always;
1966
1967 if (mapTypeMod == "attach_never")
1968 mapTypeBits |= ClauseMapFlags::attach_never;
1969
1970 if (mapTypeMod == "attach_auto")
1971 mapTypeBits |= ClauseMapFlags::attach_auto;
1972
1973 if (mapTypeMod == "ref_ptr")
1974 mapTypeBits |= ClauseMapFlags::ref_ptr;
1975
1976 if (mapTypeMod == "ref_ptee")
1977 mapTypeBits |= ClauseMapFlags::ref_ptee;
1978
1979 if (mapTypeMod == "ref_ptr_ptee")
1980 mapTypeBits |= ClauseMapFlags::ref_ptr_ptee;
1981
1982 if (mapTypeMod == "is_device_ptr")
1983 mapTypeBits |= ClauseMapFlags::is_device_ptr;
1984
1985 return success();
1986 };
1987
1988 if (parser.parseCommaSeparatedList(parseTypeAndMod))
1989 return failure();
1990
1991 mapType =
1992 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
1993
1994 return success();
1995}
1996
1997/// Prints a map_entries map type from its numeric value out into its string
1998/// format.
1999static void printMapClause(OpAsmPrinter &p, Operation *op,
2000 ClauseMapFlagsAttr mapType) {
2002 ClauseMapFlags mapFlags = mapType.getValue();
2003
2004 // handling of always, close, present placed at the beginning of the string
2005 // to aid readability
2006 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2007 mapTypeStrs.push_back("always");
2008 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2009 mapTypeStrs.push_back("implicit");
2010 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2011 mapTypeStrs.push_back("ompx_hold");
2012 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2013 mapTypeStrs.push_back("close");
2014 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2015 mapTypeStrs.push_back("present");
2016
2017 // special handling of to/from/tofrom/delete and release/alloc, release +
2018 // alloc are the abscense of one of the other flags, whereas tofrom requires
2019 // both the to and from flag to be set.
2020 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2021 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2022
2023 if (to && from)
2024 mapTypeStrs.push_back("tofrom");
2025 else if (from)
2026 mapTypeStrs.push_back("from");
2027 else if (to)
2028 mapTypeStrs.push_back("to");
2029
2030 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2031 mapTypeStrs.push_back("delete");
2032 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2033 mapTypeStrs.push_back("return_param");
2034 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2035 mapTypeStrs.push_back("storage");
2036 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2037 mapTypeStrs.push_back("private");
2038 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2039 mapTypeStrs.push_back("literal");
2040 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2041 mapTypeStrs.push_back("attach");
2042 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2043 mapTypeStrs.push_back("attach_always");
2044 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2045 mapTypeStrs.push_back("attach_never");
2046 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2047 mapTypeStrs.push_back("attach_auto");
2048 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2049 mapTypeStrs.push_back("ref_ptr");
2050 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2051 mapTypeStrs.push_back("ref_ptee");
2052 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr_ptee))
2053 mapTypeStrs.push_back("ref_ptr_ptee");
2054 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2055 mapTypeStrs.push_back("is_device_ptr");
2056 if (mapFlags == ClauseMapFlags::none)
2057 mapTypeStrs.push_back("none");
2058
2059 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2060 p << mapTypeStrs[i];
2061 if (i + 1 < mapTypeStrs.size()) {
2062 p << ", ";
2063 }
2064 }
2065}
2066
2067static ParseResult parseMembersIndex(OpAsmParser &parser,
2068 ArrayAttr &membersIdx) {
2069 SmallVector<Attribute> values, memberIdxs;
2070
2071 auto parseIndices = [&]() -> ParseResult {
2072 int64_t value;
2073 if (parser.parseInteger(value))
2074 return failure();
2075 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2076 APInt(64, value, /*isSigned=*/false)));
2077 return success();
2078 };
2079
2080 do {
2081 if (failed(parser.parseLSquare()))
2082 return failure();
2083
2084 if (parser.parseCommaSeparatedList(parseIndices))
2085 return failure();
2086
2087 if (failed(parser.parseRSquare()))
2088 return failure();
2089
2090 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2091 values.clear();
2092 } while (succeeded(parser.parseOptionalComma()));
2093
2094 if (!memberIdxs.empty())
2095 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2096
2097 return success();
2098}
2099
2100static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2101 ArrayAttr membersIdx) {
2102 if (!membersIdx)
2103 return;
2104
2105 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2106 p << "[";
2107 auto memberIdx = cast<ArrayAttr>(v);
2108 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2109 p << cast<IntegerAttr>(v2).getInt();
2110 });
2111 p << "]";
2112 });
2113}
2114
2116 VariableCaptureKindAttr mapCaptureType) {
2117 std::string typeCapStr;
2118 llvm::raw_string_ostream typeCap(typeCapStr);
2119 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2120 typeCap << "ByRef";
2121 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2122 typeCap << "ByCopy";
2123 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2124 typeCap << "VLAType";
2125 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2126 typeCap << "This";
2127 p << typeCapStr;
2128}
2129
2130static ParseResult parseCaptureType(OpAsmParser &parser,
2131 VariableCaptureKindAttr &mapCaptureType) {
2132 StringRef mapCaptureKey;
2133 if (parser.parseKeyword(&mapCaptureKey))
2134 return failure();
2135
2136 if (mapCaptureKey == "This")
2137 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2138 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2139 if (mapCaptureKey == "ByRef")
2140 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2141 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2142 if (mapCaptureKey == "ByCopy")
2143 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2144 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2145 if (mapCaptureKey == "VLAType")
2146 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2147 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2148
2149 return success();
2150}
2151
2152static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars) {
2155
2156 for (auto mapOp : mapVars) {
2157 if (!mapOp.getDefiningOp())
2158 return emitError(op->getLoc(), "missing map operation");
2159
2160 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2161 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2162
2163 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2164 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2165 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2166
2167 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2168 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2169 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2170
2171 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2172 return emitError(op->getLoc(),
2173 "to, from, tofrom and alloc map types are permitted");
2174
2175 if (isa<TargetEnterDataOp>(op) && (from || del))
2176 return emitError(op->getLoc(), "to and alloc map types are permitted");
2177
2178 if (isa<TargetExitDataOp>(op) && to)
2179 return emitError(op->getLoc(),
2180 "from, release and delete map types are permitted");
2181
2182 if (isa<TargetUpdateOp>(op)) {
2183 if (del) {
2184 return emitError(op->getLoc(),
2185 "at least one of to or from map types must be "
2186 "specified, other map types are not permitted");
2187 }
2188
2189 if (!to && !from) {
2190 return emitError(op->getLoc(),
2191 "at least one of to or from map types must be "
2192 "specified, other map types are not permitted");
2193 }
2194
2195 auto updateVar = mapInfoOp.getVarPtr();
2196
2197 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2198 (from && updateToVars.contains(updateVar))) {
2199 return emitError(
2200 op->getLoc(),
2201 "either to or from map types can be specified, not both");
2202 }
2203
2204 if (always || close || implicit) {
2205 return emitError(
2206 op->getLoc(),
2207 "present, mapper and iterator map type modifiers are permitted");
2208 }
2209
2210 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2211 }
2212 } else if (!isa<DeclareMapperInfoOp>(op)) {
2213 return emitError(op->getLoc(),
2214 "map argument is not a map entry operation");
2215 }
2216 }
2217
2218 return success();
2219}
2220
2221static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2222 std::optional<DenseI64ArrayAttr> privateMapIndices =
2223 targetOp.getPrivateMapsAttr();
2224
2225 // None of the private operands are mapped.
2226 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2227 return success();
2228
2229 OperandRange privateVars = targetOp.getPrivateVars();
2230
2231 if (privateMapIndices.value().size() !=
2232 static_cast<int64_t>(privateVars.size()))
2233 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2234 "`private_maps` attribute mismatch");
2235
2236 return success();
2237}
2238
2239//===----------------------------------------------------------------------===//
2240// MapInfoOp
2241//===----------------------------------------------------------------------===//
2242
2243static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2244 StringRef clauseName,
2245 OperandRange vars) {
2246 for (Value var : vars)
2247 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2248 return op->emitOpError()
2249 << "'" << clauseName
2250 << "' arguments must be defined by 'omp.map.info' ops";
2251 return success();
2252}
2253
2254LogicalResult MapInfoOp::verify() {
2255 if (getMapperId() &&
2257 *this, getMapperIdAttr())) {
2258 return emitError("invalid mapper id");
2259 }
2260
2261 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2262 return failure();
2263
2264 return success();
2265}
2266
2267//===----------------------------------------------------------------------===//
2268// TargetDataOp
2269//===----------------------------------------------------------------------===//
2270
2271void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2272 const TargetDataOperands &clauses) {
2273 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2274 clauses.mapVars, clauses.useDeviceAddrVars,
2275 clauses.useDevicePtrVars);
2276}
2277
2278LogicalResult TargetDataOp::verify() {
2279 if (getMapVars().empty() && getUseDevicePtrVars().empty() &&
2280 getUseDeviceAddrVars().empty()) {
2281 return ::emitError(this->getLoc(),
2282 "At least one of map, use_device_ptr_vars, or "
2283 "use_device_addr_vars operand must be present");
2284 }
2285
2286 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2287 getUseDevicePtrVars())))
2288 return failure();
2289
2290 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2291 getUseDeviceAddrVars())))
2292 return failure();
2293
2294 return verifyMapClause(*this, getMapVars());
2295}
2296
2297//===----------------------------------------------------------------------===//
2298// TargetEnterDataOp
2299//===----------------------------------------------------------------------===//
2300
2301void TargetEnterDataOp::build(
2302 OpBuilder &builder, OperationState &state,
2303 const TargetEnterExitUpdateDataOperands &clauses) {
2304 MLIRContext *ctx = builder.getContext();
2305 TargetEnterDataOp::build(
2306 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2307 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2308 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2309 clauses.nowait);
2310}
2311
2312LogicalResult TargetEnterDataOp::verify() {
2313 LogicalResult verifyDependVars =
2314 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2315 getDependIteratedKinds(), getDependIterated());
2316 return failed(verifyDependVars) ? verifyDependVars
2317 : verifyMapClause(*this, getMapVars());
2318}
2319
2320//===----------------------------------------------------------------------===//
2321// TargetExitDataOp
2322//===----------------------------------------------------------------------===//
2323
2324void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2325 const TargetEnterExitUpdateDataOperands &clauses) {
2326 MLIRContext *ctx = builder.getContext();
2327 TargetExitDataOp::build(
2328 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2329 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2330 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2331 clauses.nowait);
2332}
2333
2334LogicalResult TargetExitDataOp::verify() {
2335 LogicalResult verifyDependVars =
2336 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2337 getDependIteratedKinds(), getDependIterated());
2338 return failed(verifyDependVars) ? verifyDependVars
2339 : verifyMapClause(*this, getMapVars());
2340}
2341
2342//===----------------------------------------------------------------------===//
2343// TargetUpdateOp
2344//===----------------------------------------------------------------------===//
2345
2346void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2347 const TargetEnterExitUpdateDataOperands &clauses) {
2348 MLIRContext *ctx = builder.getContext();
2349 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2350 clauses.dependVars,
2351 makeArrayAttr(ctx, clauses.dependIteratedKinds),
2352 clauses.dependIterated, clauses.device, clauses.ifExpr,
2353 clauses.mapVars, clauses.nowait);
2354}
2355
2356LogicalResult TargetUpdateOp::verify() {
2357 LogicalResult verifyDependVars =
2358 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2359 getDependIteratedKinds(), getDependIterated());
2360 return failed(verifyDependVars) ? verifyDependVars
2361 : verifyMapClause(*this, getMapVars());
2362}
2363
2364//===----------------------------------------------------------------------===//
2365// TargetOp
2366//===----------------------------------------------------------------------===//
2367
2368void TargetOp::build(OpBuilder &builder, OperationState &state,
2369 const TargetOperands &clauses) {
2370 MLIRContext *ctx = builder.getContext();
2371 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2372 // inReductionByref, inReductionSyms.
2373 TargetOp::build(
2374 builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
2375 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2376 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2377 clauses.device, clauses.hasDeviceAddrVars, clauses.hostEvalVars,
2378 clauses.ifExpr,
2379 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2380 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2381 clauses.nowait, clauses.privateVars,
2382 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2383 clauses.threadLimitVars,
2384 /*private_maps=*/nullptr);
2385}
2386
2387LogicalResult TargetOp::verify() {
2388 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2389 getDependIteratedKinds(),
2390 getDependIterated())))
2391 return failure();
2392
2393 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2394 getHasDeviceAddrVars())))
2395 return failure();
2396
2397 if (failed(verifyMapClause(*this, getMapVars())))
2398 return failure();
2399
2400 return verifyPrivateVarsMapping(*this);
2401}
2402
2403LogicalResult TargetOp::verifyRegions() {
2404 auto teamsOps = getOps<TeamsOp>();
2405 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2406 return emitError("target containing multiple 'omp.teams' nested ops");
2407
2408 // Check that host_eval values are only used in legal ways.
2409 Operation *capturedOp = getInnermostCapturedOmpOp();
2410 TargetRegionFlags execFlags = getKernelExecFlags(capturedOp);
2411 for (Value hostEvalArg :
2412 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2413 for (Operation *user : hostEvalArg.getUsers()) {
2414 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2415 // Check if used in num_teams_lower or any of num_teams_upper_vars
2416 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2417 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2418 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2419 continue;
2420
2421 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2422 "and 'thread_limit' in 'omp.teams'";
2423 }
2424 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2425 if (bitEnumContainsAny(execFlags, TargetRegionFlags::spmd) &&
2426 parallelOp->isAncestor(capturedOp) &&
2427 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2428 continue;
2429
2430 return emitOpError()
2431 << "host_eval argument only legal as 'num_threads' in "
2432 "'omp.parallel' when representing target SPMD";
2433 }
2434 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2435 if (bitEnumContainsAny(execFlags, TargetRegionFlags::trip_count) &&
2436 loopNestOp.getOperation() == capturedOp &&
2437 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2438 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2439 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2440 continue;
2441
2442 return emitOpError() << "host_eval argument only legal as loop bounds "
2443 "and steps in 'omp.loop_nest' when trip count "
2444 "must be evaluated in the host";
2445 }
2446
2447 return emitOpError() << "host_eval argument illegal use in '"
2448 << user->getName() << "' operation";
2449 }
2450 }
2451 return success();
2452}
2453
2454static Operation *
2455findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
2456 llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
2457 assert(rootOp && "expected valid operation");
2458
2459 Dialect *ompDialect = rootOp->getDialect();
2460 Operation *capturedOp = nullptr;
2461 DominanceInfo domInfo;
2462
2463 // Process in pre-order to check operations from outermost to innermost,
2464 // ensuring we only enter the region of an operation if it meets the criteria
2465 // for being captured. We stop the exploration of nested operations as soon as
2466 // we process a region holding no operations to be captured.
2467 rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
2468 if (op == rootOp)
2469 return WalkResult::advance();
2470
2471 // Ignore operations of other dialects or omp operations with no regions,
2472 // because these will only be checked if they are siblings of an omp
2473 // operation that can potentially be captured.
2474 bool isOmpDialect = op->getDialect() == ompDialect;
2475 bool hasRegions = op->getNumRegions() > 0;
2476 if (!isOmpDialect || !hasRegions)
2477 return WalkResult::skip();
2478
2479 // This operation cannot be captured if it can be executed more than once
2480 // (i.e. its block's successors can reach it) or if it's not guaranteed to
2481 // be executed before all exits of the region (i.e. it doesn't dominate all
2482 // blocks with no successors reachable from the entry block).
2483 if (checkSingleMandatoryExec) {
2484 Region *parentRegion = op->getParentRegion();
2485 Block *parentBlock = op->getBlock();
2486
2487 for (Block *successor : parentBlock->getSuccessors())
2488 if (successor->isReachable(parentBlock))
2489 return WalkResult::interrupt();
2490
2491 for (Block &block : *parentRegion)
2492 if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
2493 !domInfo.dominates(parentBlock, &block))
2494 return WalkResult::interrupt();
2495 }
2496
2497 // Don't capture this op if it has a not-allowed sibling, and stop recursing
2498 // into nested operations.
2499 for (Operation &sibling : op->getParentRegion()->getOps())
2500 if (&sibling != op && !siblingAllowedFn(&sibling))
2501 return WalkResult::interrupt();
2502
2503 // Don't continue capturing nested operations if we reach an omp.loop_nest.
2504 // Otherwise, process the contents of this operation.
2505 capturedOp = op;
2506 return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
2508 });
2509
2510 return capturedOp;
2511}
2512
2513Operation *TargetOp::getInnermostCapturedOmpOp() {
2514 auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();
2515
2516 // Only allow OpenMP terminators and non-OpenMP ops that have known memory
2517 // effects, but don't include a memory write effect.
2518 return findCapturedOmpOp(
2519 *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
2520 if (!sibling)
2521 return false;
2522
2523 if (ompDialect == sibling->getDialect())
2524 return sibling->hasTrait<OpTrait::IsTerminator>();
2525
2526 if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
2528 effects;
2529 memOp.getEffects(effects);
2530 return !llvm::any_of(
2531 effects, [&](MemoryEffects::EffectInstance &effect) {
2532 return isa<MemoryEffects::Write>(effect.getEffect()) &&
2533 isa<SideEffects::AutomaticAllocationScopeResource>(
2534 effect.getResource());
2535 });
2536 }
2537 return true;
2538 });
2539}
2540
2541/// Check if we can promote SPMD kernel to No-Loop kernel.
2542static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2543 WsloopOp *wsLoopOp) {
2544 // num_teams clause can break no-loop teams/threads assumption.
2545 if (!teamsOp.getNumTeamsUpperVars().empty())
2546 return false;
2547
2548 // Reduction kernels are slower in no-loop mode.
2549 if (teamsOp.getNumReductionVars())
2550 return false;
2551 if (wsLoopOp->getNumReductionVars())
2552 return false;
2553
2554 // Check if the user allows the promotion of kernels to no-loop mode.
2555 OffloadModuleInterface offloadMod =
2556 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2557 if (!offloadMod)
2558 return false;
2559 auto ompFlags = offloadMod.getFlags();
2560 if (!ompFlags)
2561 return false;
2562 return ompFlags.getAssumeTeamsOversubscription() &&
2563 ompFlags.getAssumeThreadsOversubscription();
2564}
2565
2566TargetRegionFlags TargetOp::getKernelExecFlags(Operation *capturedOp) {
2567 // A non-null captured op is only valid if it resides inside of a TargetOp
2568 // and is the result of calling getInnermostCapturedOmpOp() on it.
2569 TargetOp targetOp =
2570 capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
2571 assert((!capturedOp ||
2572 (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
2573 "unexpected captured op");
2574
2575 // If it's not capturing a loop, it's a default target region.
2576 if (!isa_and_present<LoopNestOp>(capturedOp))
2577 return TargetRegionFlags::generic;
2578
2579 // Get the innermost non-simd loop wrapper.
2581 cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
2582 assert(!loopWrappers.empty());
2583
2584 LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
2585 if (isa<SimdOp>(innermostWrapper))
2586 innermostWrapper = std::next(innermostWrapper);
2587
2588 auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
2589 if (numWrappers != 1 && numWrappers != 2)
2590 return TargetRegionFlags::generic;
2591
2592 // Detect target-teams-distribute-parallel-wsloop[-simd].
2593 if (numWrappers == 2) {
2594 WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
2595 if (!wsloopOp)
2596 return TargetRegionFlags::generic;
2597
2598 innermostWrapper = std::next(innermostWrapper);
2599 if (!isa<DistributeOp>(innermostWrapper))
2600 return TargetRegionFlags::generic;
2601
2602 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2603 if (!isa_and_present<ParallelOp>(parallelOp))
2604 return TargetRegionFlags::generic;
2605
2606 TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
2607 if (!teamsOp)
2608 return TargetRegionFlags::generic;
2609
2610 if (teamsOp->getParentOp() == targetOp.getOperation()) {
2611 TargetRegionFlags result =
2612 TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2613 if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
2614 result = result | TargetRegionFlags::no_loop;
2615 return result;
2616 }
2617 }
2618 // Detect target-teams-distribute[-simd] and target-teams-loop.
2619 else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
2620 Operation *teamsOp = (*innermostWrapper)->getParentOp();
2621 if (!isa_and_present<TeamsOp>(teamsOp))
2622 return TargetRegionFlags::generic;
2623
2624 if (teamsOp->getParentOp() != targetOp.getOperation())
2625 return TargetRegionFlags::generic;
2626
2627 if (isa<LoopOp>(innermostWrapper))
2628 return TargetRegionFlags::spmd | TargetRegionFlags::trip_count;
2629
2630 // Find single immediately nested captured omp.parallel and add spmd flag
2631 // (generic-spmd case).
2632 //
2633 // TODO: This shouldn't have to be done here, as it is too easy to break.
2634 // The openmp-opt pass should be updated to be able to promote kernels like
2635 // this from "Generic" to "Generic-SPMD". However, the use of the
2636 // `kmpc_distribute_static_loop` family of functions produced by the
2637 // OMPIRBuilder for these kernels prevents that from working.
2638 Dialect *ompDialect = targetOp->getDialect();
2639 Operation *nestedCapture = findCapturedOmpOp(
2640 capturedOp, /*checkSingleMandatoryExec=*/false,
2641 [&](Operation *sibling) {
2642 return sibling && (ompDialect != sibling->getDialect() ||
2643 sibling->hasTrait<OpTrait::IsTerminator>());
2644 });
2645
2646 TargetRegionFlags result =
2647 TargetRegionFlags::generic | TargetRegionFlags::trip_count;
2648
2649 if (!nestedCapture)
2650 return result;
2651
2652 while (nestedCapture->getParentOp() != capturedOp)
2653 nestedCapture = nestedCapture->getParentOp();
2654
2655 return isa<ParallelOp>(nestedCapture) ? result | TargetRegionFlags::spmd
2656 : result;
2657 }
2658 // Detect target-parallel-wsloop[-simd].
2659 else if (isa<WsloopOp>(innermostWrapper)) {
2660 Operation *parallelOp = (*innermostWrapper)->getParentOp();
2661 if (!isa_and_present<ParallelOp>(parallelOp))
2662 return TargetRegionFlags::generic;
2663
2664 if (parallelOp->getParentOp() == targetOp.getOperation())
2665 return TargetRegionFlags::spmd;
2666 }
2667
2668 return TargetRegionFlags::generic;
2669}
2670
2671//===----------------------------------------------------------------------===//
2672// ParallelOp
2673//===----------------------------------------------------------------------===//
2674
2675void ParallelOp::build(OpBuilder &builder, OperationState &state,
2676 ArrayRef<NamedAttribute> attributes) {
2677 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2678 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2679 /*num_threads_vars=*/ValueRange(),
2680 /*private_vars=*/ValueRange(),
2681 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2682 /*proc_bind_kind=*/nullptr,
2683 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2684 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2685 state.addAttributes(attributes);
2686}
2687
2688void ParallelOp::build(OpBuilder &builder, OperationState &state,
2689 const ParallelOperands &clauses) {
2690 MLIRContext *ctx = builder.getContext();
2691 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2692 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2693 makeArrayAttr(ctx, clauses.privateSyms),
2694 clauses.privateNeedsBarrier, clauses.procBindKind,
2695 clauses.reductionMod, clauses.reductionVars,
2696 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2697 makeArrayAttr(ctx, clauses.reductionSyms));
2698}
2699
2700template <typename OpType>
2701static LogicalResult verifyPrivateVarList(OpType &op) {
2702 auto privateVars = op.getPrivateVars();
2703 auto privateSyms = op.getPrivateSymsAttr();
2704
2705 if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
2706 return success();
2707
2708 auto numPrivateVars = privateVars.size();
2709 auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();
2710
2711 if (numPrivateVars != numPrivateSyms)
2712 return op.emitError() << "inconsistent number of private variables and "
2713 "privatizer op symbols, private vars: "
2714 << numPrivateVars
2715 << " vs. privatizer op symbols: " << numPrivateSyms;
2716
2717 for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
2718 Type varType = std::get<0>(privateVarInfo).getType();
2719 SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
2720 PrivateClauseOp privatizerOp =
2722
2723 if (privatizerOp == nullptr)
2724 return op.emitError() << "failed to lookup privatizer op with symbol: '"
2725 << privateSym << "'";
2726
2727 Type privatizerType = privatizerOp.getArgType();
2728
2729 if (privatizerType && (varType != privatizerType))
2730 return op.emitError()
2731 << "type mismatch between a "
2732 << (privatizerOp.getDataSharingType() ==
2733 DataSharingClauseType::Private
2734 ? "private"
2735 : "firstprivate")
2736 << " variable and its privatizer op, var type: " << varType
2737 << " vs. privatizer op type: " << privatizerType;
2738 }
2739
2740 return success();
2741}
2742
2743LogicalResult ParallelOp::verify() {
2744 if (getAllocateVars().size() != getAllocatorVars().size())
2745 return emitError(
2746 "expected equal sizes for allocate and allocator variables");
2747
2748 if (failed(verifyPrivateVarList(*this)))
2749 return failure();
2750
2751 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2752 getReductionByref());
2753}
2754
2755LogicalResult ParallelOp::verifyRegions() {
2756 auto distChildOps = getOps<DistributeOp>();
2757 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2758 if (numDistChildOps > 1)
2759 return emitError()
2760 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2761
2762 if (numDistChildOps == 1) {
2763 if (!isComposite())
2764 return emitError()
2765 << "'omp.composite' attribute missing from composite operation";
2766
2767 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2768 Operation &distributeOp = **distChildOps.begin();
2769 for (Operation &childOp : getOps()) {
2770 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2771 continue;
2772
2773 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2774 return emitError() << "unexpected OpenMP operation inside of composite "
2775 "'omp.parallel': "
2776 << childOp.getName();
2777 }
2778 } else if (isComposite()) {
2779 return emitError()
2780 << "'omp.composite' attribute present in non-composite operation";
2781 }
2782 return success();
2783}
2784
2785//===----------------------------------------------------------------------===//
2786// TeamsOp
2787//===----------------------------------------------------------------------===//
2788
2790 while ((op = op->getParentOp()))
2791 if (isa<OpenMPDialect>(op->getDialect()))
2792 return false;
2793 return true;
2794}
2795
2796void TeamsOp::build(OpBuilder &builder, OperationState &state,
2797 const TeamsOperands &clauses) {
2798 MLIRContext *ctx = builder.getContext();
2799 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2800 TeamsOp::build(
2801 builder, state, clauses.allocateVars, clauses.allocatorVars,
2802 clauses.ifExpr, clauses.numTeamsLower, clauses.numTeamsUpperVars,
2803 /*private_vars=*/{}, /*private_syms=*/nullptr,
2804 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
2805 clauses.reductionVars,
2806 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2807 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
2808}
2809
2810// Verify num_teams clause
2811static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
2812 OperandRange numTeamsUpperVars) {
2813 // If lower is specified, upper must have exactly one value
2814 if (numTeamsLower) {
2815 if (numTeamsUpperVars.size() != 1)
2816 return op->emitError(
2817 "expected exactly one num_teams upper bound when lower bound is "
2818 "specified");
2819 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
2820 return op->emitError(
2821 "expected num_teams upper bound and lower bound to be "
2822 "the same type");
2823 }
2824
2825 return success();
2826}
2827
2828LogicalResult TeamsOp::verify() {
2829 // Check parent region
2830 // TODO If nested inside of a target region, also check that it does not
2831 // contain any statements, declarations or directives other than this
2832 // omp.teams construct. The issue is how to support the initialization of
2833 // this operation's own arguments (allow SSA values across omp.target?).
2834 Operation *op = getOperation();
2835 if (!isa<TargetOp>(op->getParentOp()) &&
2837 return emitError("expected to be nested inside of omp.target or not nested "
2838 "in any OpenMP dialect operations");
2839
2840 // Check for num_teams clause restrictions
2841 if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
2842 this->getNumTeamsUpperVars())))
2843 return failure();
2844
2845 // Check for allocate clause restrictions
2846 if (getAllocateVars().size() != getAllocatorVars().size())
2847 return emitError(
2848 "expected equal sizes for allocate and allocator variables");
2849
2850 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2851 getReductionByref());
2852}
2853
2854//===----------------------------------------------------------------------===//
2855// SectionOp
2856//===----------------------------------------------------------------------===//
2857
2858OperandRange SectionOp::getPrivateVars() {
2859 return getParentOp().getPrivateVars();
2860}
2861
2862OperandRange SectionOp::getReductionVars() {
2863 return getParentOp().getReductionVars();
2864}
2865
2866//===----------------------------------------------------------------------===//
2867// SectionsOp
2868//===----------------------------------------------------------------------===//
2869
2870void SectionsOp::build(OpBuilder &builder, OperationState &state,
2871 const SectionsOperands &clauses) {
2872 MLIRContext *ctx = builder.getContext();
2873 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2874 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2875 clauses.nowait, /*private_vars=*/{},
2876 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2877 clauses.reductionMod, clauses.reductionVars,
2878 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2879 makeArrayAttr(ctx, clauses.reductionSyms));
2880}
2881
2882LogicalResult SectionsOp::verify() {
2883 if (getAllocateVars().size() != getAllocatorVars().size())
2884 return emitError(
2885 "expected equal sizes for allocate and allocator variables");
2886
2887 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2888 getReductionByref());
2889}
2890
2891LogicalResult SectionsOp::verifyRegions() {
2892 for (auto &inst : *getRegion().begin()) {
2893 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
2894 return emitOpError()
2895 << "expected omp.section op or terminator op inside region";
2896 }
2897 }
2898
2899 return success();
2900}
2901
2902//===----------------------------------------------------------------------===//
2903// SingleOp
2904//===----------------------------------------------------------------------===//
2905
2906void SingleOp::build(OpBuilder &builder, OperationState &state,
2907 const SingleOperands &clauses) {
2908 MLIRContext *ctx = builder.getContext();
2909 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
2910 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2911 clauses.copyprivateVars,
2912 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
2913 /*private_vars=*/{}, /*private_syms=*/nullptr,
2914 /*private_needs_barrier=*/nullptr);
2915}
2916
2917LogicalResult SingleOp::verify() {
2918 // Check for allocate clause restrictions
2919 if (getAllocateVars().size() != getAllocatorVars().size())
2920 return emitError(
2921 "expected equal sizes for allocate and allocator variables");
2922
2923 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
2924 getCopyprivateSyms());
2925}
2926
2927//===----------------------------------------------------------------------===//
2928// WorkshareOp
2929//===----------------------------------------------------------------------===//
2930
2931void WorkshareOp::build(OpBuilder &builder, OperationState &state,
2932 const WorkshareOperands &clauses) {
2933 WorkshareOp::build(builder, state, clauses.nowait);
2934}
2935
2936//===----------------------------------------------------------------------===//
2937// WorkshareLoopWrapperOp
2938//===----------------------------------------------------------------------===//
2939
2940LogicalResult WorkshareLoopWrapperOp::verify() {
2941 if (!(*this)->getParentOfType<WorkshareOp>())
2942 return emitOpError() << "must be nested in an omp.workshare";
2943 return success();
2944}
2945
2946LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
2947 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
2948 getNestedWrapper())
2949 return emitOpError() << "expected to be a standalone loop wrapper";
2950
2951 return success();
2952}
2953
2954//===----------------------------------------------------------------------===//
2955// LoopWrapperInterface
2956//===----------------------------------------------------------------------===//
2957
2958LogicalResult LoopWrapperInterface::verifyImpl() {
2959 Operation *op = this->getOperation();
2960 if (!op->hasTrait<OpTrait::NoTerminator>() ||
2962 return emitOpError() << "loop wrapper must also have the `NoTerminator` "
2963 "and `SingleBlock` traits";
2964
2965 if (op->getNumRegions() != 1)
2966 return emitOpError() << "loop wrapper does not contain exactly one region";
2967
2968 Region &region = op->getRegion(0);
2969 if (range_size(region.getOps()) != 1)
2970 return emitOpError()
2971 << "loop wrapper does not contain exactly one nested op";
2972
2973 Operation &firstOp = *region.op_begin();
2974 if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
2975 return emitOpError() << "nested in loop wrapper is not another loop "
2976 "wrapper or `omp.loop_nest`";
2977
2978 return success();
2979}
2980
2981//===----------------------------------------------------------------------===//
2982// LoopOp
2983//===----------------------------------------------------------------------===//
2984
2985void LoopOp::build(OpBuilder &builder, OperationState &state,
2986 const LoopOperands &clauses) {
2987 MLIRContext *ctx = builder.getContext();
2988
2989 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
2990 makeArrayAttr(ctx, clauses.privateSyms),
2991 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
2992 clauses.reductionMod, clauses.reductionVars,
2993 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2994 makeArrayAttr(ctx, clauses.reductionSyms));
2995}
2996
2997LogicalResult LoopOp::verify() {
2998 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2999 getReductionByref());
3000}
3001
3002LogicalResult LoopOp::verifyRegions() {
3003 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3004 getNestedWrapper())
3005 return emitOpError() << "expected to be a standalone loop wrapper";
3006
3007 return success();
3008}
3009
3010//===----------------------------------------------------------------------===//
3011// WsloopOp
3012//===----------------------------------------------------------------------===//
3013
3014void WsloopOp::build(OpBuilder &builder, OperationState &state,
3015 ArrayRef<NamedAttribute> attributes) {
3016 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3017 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3018 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3019 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3020 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3021 /*private_needs_barrier=*/false,
3022 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3023 /*reduction_byref=*/nullptr,
3024 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3025 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3026 /*schedule_simd=*/false);
3027 state.addAttributes(attributes);
3028}
3029
3030void WsloopOp::build(OpBuilder &builder, OperationState &state,
3031 const WsloopOperands &clauses) {
3032 MLIRContext *ctx = builder.getContext();
3033 // TODO: Store clauses in op: allocateVars, allocatorVars
3034 WsloopOp::build(
3035 builder, state,
3036 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3037 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3038 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3039 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3040 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3041 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3042 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3043 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3044}
3045
3046LogicalResult WsloopOp::verify() {
3047 if (failed(
3048 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3049 return failure();
3050 if (getLinearVars().size() &&
3051 getLinearVarTypes().value().size() != getLinearVars().size())
3052 return emitError() << "Ill-formed type attributes for linear variables";
3053 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3054 getReductionByref());
3055}
3056
3057LogicalResult WsloopOp::verifyRegions() {
3058 bool isCompositeChildLeaf =
3059 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3060
3061 if (LoopWrapperInterface nested = getNestedWrapper()) {
3062 if (!isComposite())
3063 return emitError()
3064 << "'omp.composite' attribute missing from composite wrapper";
3065
3066 // Check for the allowed leaf constructs that may appear in a composite
3067 // construct directly after DO/FOR.
3068 if (!isa<SimdOp>(nested))
3069 return emitError() << "only supported nested wrapper is 'omp.simd'";
3070
3071 } else if (isComposite() && !isCompositeChildLeaf) {
3072 return emitError()
3073 << "'omp.composite' attribute present in non-composite wrapper";
3074 } else if (!isComposite() && isCompositeChildLeaf) {
3075 return emitError()
3076 << "'omp.composite' attribute missing from composite wrapper";
3077 }
3078
3079 return success();
3080}
3081
3082//===----------------------------------------------------------------------===//
3083// Simd construct [2.9.3.1]
3084//===----------------------------------------------------------------------===//
3085
3086void SimdOp::build(OpBuilder &builder, OperationState &state,
3087 const SimdOperands &clauses) {
3088 MLIRContext *ctx = builder.getContext();
3089 SimdOp::build(builder, state, clauses.alignedVars,
3090 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3091 clauses.linearVars, clauses.linearStepVars,
3092 clauses.linearVarTypes, clauses.linearModifiers,
3093 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3094 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3095 clauses.privateNeedsBarrier, clauses.reductionMod,
3096 clauses.reductionVars,
3097 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3098 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3099 clauses.simdlen);
3100}
3101
3102LogicalResult SimdOp::verify() {
3103 if (getSimdlen().has_value() && getSafelen().has_value() &&
3104 getSimdlen().value() > getSafelen().value())
3105 return emitOpError()
3106 << "simdlen clause and safelen clause are both present, but the "
3107 "simdlen value is not less than or equal to safelen value";
3108
3109 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
3110 return failure();
3111
3112 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
3113 return failure();
3114
3115 if (failed(
3116 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3117 return failure();
3118
3119 bool isCompositeChildLeaf =
3120 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3121
3122 if (!isComposite() && isCompositeChildLeaf)
3123 return emitError()
3124 << "'omp.composite' attribute missing from composite wrapper";
3125
3126 if (isComposite() && !isCompositeChildLeaf)
3127 return emitError()
3128 << "'omp.composite' attribute present in non-composite wrapper";
3129
3130 // Firstprivate is not allowed for SIMD in the standard. Check that none of
3131 // the private decls are for firstprivate.
3132 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3133 if (privateSyms) {
3134 for (const Attribute &sym : *privateSyms) {
3135 auto symRef = cast<SymbolRefAttr>(sym);
3136 omp::PrivateClauseOp privatizer =
3138 getOperation(), symRef);
3139 if (!privatizer)
3140 return emitError() << "Cannot find privatizer '" << symRef << "'";
3141 if (privatizer.getDataSharingType() ==
3142 DataSharingClauseType::FirstPrivate)
3143 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
3144 }
3145 }
3146
3147 if (getLinearVars().size() &&
3148 getLinearVarTypes().value().size() != getLinearVars().size())
3149 return emitError() << "Ill-formed type attributes for linear variables";
3150 return success();
3151}
3152
3153LogicalResult SimdOp::verifyRegions() {
3154 if (getNestedWrapper())
3155 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3156
3157 return success();
3158}
3159
3160//===----------------------------------------------------------------------===//
3161// Distribute construct [2.9.4.1]
3162//===----------------------------------------------------------------------===//
3163
3164void DistributeOp::build(OpBuilder &builder, OperationState &state,
3165 const DistributeOperands &clauses) {
3166 DistributeOp::build(builder, state, clauses.allocateVars,
3167 clauses.allocatorVars, clauses.distScheduleStatic,
3168 clauses.distScheduleChunkSize, clauses.order,
3169 clauses.orderMod, clauses.privateVars,
3170 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3171 clauses.privateNeedsBarrier);
3172}
3173
3174LogicalResult DistributeOp::verify() {
3175 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3176 return emitOpError() << "chunk size set without "
3177 "dist_schedule_static being present";
3178
3179 if (getAllocateVars().size() != getAllocatorVars().size())
3180 return emitError(
3181 "expected equal sizes for allocate and allocator variables");
3182
3183 return success();
3184}
3185
3186LogicalResult DistributeOp::verifyRegions() {
3187 if (LoopWrapperInterface nested = getNestedWrapper()) {
3188 if (!isComposite())
3189 return emitError()
3190 << "'omp.composite' attribute missing from composite wrapper";
3191 // Check for the allowed leaf constructs that may appear in a composite
3192 // construct directly after DISTRIBUTE.
3193 if (isa<WsloopOp>(nested)) {
3194 Operation *parentOp = (*this)->getParentOp();
3195 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3196 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3197 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3198 "when a composite 'omp.parallel' is the direct "
3199 "parent";
3200 }
3201 } else if (!isa<SimdOp>(nested))
3202 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3203 "'omp.wsloop'";
3204 } else if (isComposite()) {
3205 return emitError()
3206 << "'omp.composite' attribute present in non-composite wrapper";
3207 }
3208
3209 return success();
3210}
3211
3212//===----------------------------------------------------------------------===//
3213// DeclareMapperOp / DeclareMapperInfoOp
3214//===----------------------------------------------------------------------===//
3215
3216LogicalResult DeclareMapperInfoOp::verify() {
3217 return verifyMapClause(*this, getMapVars());
3218}
3219
3220LogicalResult DeclareMapperOp::verifyRegions() {
3221 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3222 getRegion().getBlocks().front().getTerminator()))
3223 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3224
3225 return success();
3226}
3227
3228//===----------------------------------------------------------------------===//
3229// DeclareReductionOp
3230//===----------------------------------------------------------------------===//
3231
3232LogicalResult DeclareReductionOp::verifyRegions() {
3233 if (!getAllocRegion().empty()) {
3234 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3235 if (yieldOp.getResults().size() != 1 ||
3236 yieldOp.getResults().getTypes()[0] != getType())
3237 return emitOpError() << "expects alloc region to yield a value "
3238 "of the reduction type";
3239 }
3240 }
3241
3242 if (getInitializerRegion().empty())
3243 return emitOpError() << "expects non-empty initializer region";
3244 Block &initializerEntryBlock = getInitializerRegion().front();
3245
3246 if (initializerEntryBlock.getNumArguments() == 1) {
3247 if (!getAllocRegion().empty())
3248 return emitOpError() << "expects two arguments to the initializer region "
3249 "when an allocation region is used";
3250 } else if (initializerEntryBlock.getNumArguments() == 2) {
3251 if (getAllocRegion().empty())
3252 return emitOpError() << "expects one argument to the initializer region "
3253 "when no allocation region is used";
3254 } else {
3255 return emitOpError()
3256 << "expects one or two arguments to the initializer region";
3257 }
3258
3259 for (mlir::Value arg : initializerEntryBlock.getArguments())
3260 if (arg.getType() != getType())
3261 return emitOpError() << "expects initializer region argument to match "
3262 "the reduction type";
3263
3264 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3265 if (yieldOp.getResults().size() != 1 ||
3266 yieldOp.getResults().getTypes()[0] != getType())
3267 return emitOpError() << "expects initializer region to yield a value "
3268 "of the reduction type";
3269 }
3270
3271 if (getReductionRegion().empty())
3272 return emitOpError() << "expects non-empty reduction region";
3273 Block &reductionEntryBlock = getReductionRegion().front();
3274 if (reductionEntryBlock.getNumArguments() != 2 ||
3275 reductionEntryBlock.getArgumentTypes()[0] !=
3276 reductionEntryBlock.getArgumentTypes()[1] ||
3277 reductionEntryBlock.getArgumentTypes()[0] != getType())
3278 return emitOpError() << "expects reduction region with two arguments of "
3279 "the reduction type";
3280 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3281 if (yieldOp.getResults().size() != 1 ||
3282 yieldOp.getResults().getTypes()[0] != getType())
3283 return emitOpError() << "expects reduction region to yield a value "
3284 "of the reduction type";
3285 }
3286
3287 if (!getAtomicReductionRegion().empty()) {
3288 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3289 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3290 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3291 atomicReductionEntryBlock.getArgumentTypes()[1])
3292 return emitOpError() << "expects atomic reduction region with two "
3293 "arguments of the same type";
3294 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3295 atomicReductionEntryBlock.getArgumentTypes()[0]);
3296 if (!ptrType ||
3297 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3298 return emitOpError() << "expects atomic reduction region arguments to "
3299 "be accumulators containing the reduction type";
3300 }
3301
3302 if (getCleanupRegion().empty())
3303 return success();
3304 Block &cleanupEntryBlock = getCleanupRegion().front();
3305 if (cleanupEntryBlock.getNumArguments() != 1 ||
3306 cleanupEntryBlock.getArgument(0).getType() != getType())
3307 return emitOpError() << "expects cleanup region with one argument "
3308 "of the reduction type";
3309
3310 return success();
3311}
3312
3313//===----------------------------------------------------------------------===//
3314// TaskOp
3315//===----------------------------------------------------------------------===//
3316
3317void TaskOp::build(OpBuilder &builder, OperationState &state,
3318 const TaskOperands &clauses) {
3319 MLIRContext *ctx = builder.getContext();
3320 TaskOp::build(
3321 builder, state, clauses.iterated, clauses.affinityVars,
3322 clauses.allocateVars, clauses.allocatorVars,
3323 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3324 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3325 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3326 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3327 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3328 clauses.priority, /*private_vars=*/clauses.privateVars,
3329 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3330 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3331}
3332
3333LogicalResult TaskOp::verify() {
3334 LogicalResult verifyDependVars =
3335 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3336 getDependIteratedKinds(), getDependIterated());
3337 return failed(verifyDependVars)
3338 ? verifyDependVars
3339 : verifyReductionVarList(*this, getInReductionSyms(),
3340 getInReductionVars(),
3341 getInReductionByref());
3342}
3343
3344//===----------------------------------------------------------------------===//
3345// TaskgroupOp
3346//===----------------------------------------------------------------------===//
3347
3348void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3349 const TaskgroupOperands &clauses) {
3350 MLIRContext *ctx = builder.getContext();
3351 TaskgroupOp::build(builder, state, clauses.allocateVars,
3352 clauses.allocatorVars, clauses.taskReductionVars,
3353 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3354 makeArrayAttr(ctx, clauses.taskReductionSyms));
3355}
3356
3357LogicalResult TaskgroupOp::verify() {
3358 return verifyReductionVarList(*this, getTaskReductionSyms(),
3359 getTaskReductionVars(),
3360 getTaskReductionByref());
3361}
3362
3363//===----------------------------------------------------------------------===//
3364// TaskloopContextOp
3365//===----------------------------------------------------------------------===//
3366
3367void TaskloopContextOp::build(OpBuilder &builder, OperationState &state,
3368 const TaskloopContextOperands &clauses) {
3369 MLIRContext *ctx = builder.getContext();
3370 TaskloopContextOp::build(
3371 builder, state, clauses.allocateVars, clauses.allocatorVars,
3372 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3373 clauses.inReductionVars,
3374 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3375 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3376 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3377 /*private_vars=*/clauses.privateVars,
3378 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3379 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3380 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3381 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3382}
3383
3384TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3385 return cast<TaskloopWrapperOp>(
3386 *llvm::find_if(getRegion().front(), [](mlir::Operation &op) {
3387 return isa<TaskloopWrapperOp>(op);
3388 }));
3389}
3390
3391LogicalResult TaskloopContextOp::verify() {
3392 if (getAllocateVars().size() != getAllocatorVars().size())
3393 return emitError(
3394 "expected equal sizes for allocate and allocator variables");
3395 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3396 getReductionVars(), getReductionByref())) ||
3397 failed(verifyReductionVarList(*this, getInReductionSyms(),
3398 getInReductionVars(),
3399 getInReductionByref())))
3400 return failure();
3401
3402 if (!getReductionVars().empty() && getNogroup())
3403 return emitError("if a reduction clause is present on the taskloop "
3404 "directive, the nogroup clause must not be specified");
3405 for (auto var : getReductionVars()) {
3406 if (llvm::is_contained(getInReductionVars(), var))
3407 return emitError("the same list item cannot appear in both a reduction "
3408 "and an in_reduction clause");
3409 }
3410
3411 if (getGrainsize() && getNumTasks()) {
3412 return emitError(
3413 "the grainsize clause and num_tasks clause are mutually exclusive and "
3414 "may not appear on the same taskloop directive");
3415 }
3416
3417 return success();
3418}
3419
3420LogicalResult TaskloopContextOp::verifyRegions() {
3421 Region &region = getRegion();
3422 if (region.empty())
3423 return emitOpError() << "expected non-empty region";
3424
3425 auto count = llvm::count_if(region.front(), [](mlir::Operation &op) {
3426 return isa<TaskloopWrapperOp>(op);
3427 });
3428 if (count != 1)
3429 return emitOpError()
3430 << "expected exactly 1 TaskloopWrapperOp directly nested in "
3431 "the region, but "
3432 << count << " were found";
3433 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3434
3435 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3436 // This will fail the verifier for TaskloopWrapperOp and print an error
3437 // message there.
3438 if (!loopNestOp)
3439 return failure();
3440
3441 std::function<bool(Value)> isValidBoundValue = [&](Value value) -> bool {
3442 Region *valueRegion = value.getParentRegion();
3443 // A loop bound value defined outside of the taskloop context region is
3444 // valid. A region is considered an ancestor of itself.
3445 if (!region.isAncestor(valueRegion))
3446 return true;
3447
3448 Operation *defOp = value.getDefiningOp();
3449 if (!defOp || defOp->getNumRegions() != 0 || !isPure(defOp))
3450 return false;
3451
3452 return llvm::all_of(defOp->getOperands(), isValidBoundValue);
3453 };
3454 auto hasUnsupportedTaskloopLocalBound = [&](OperandRange range) -> bool {
3455 return llvm::any_of(range,
3456 [&](Value value) { return !isValidBoundValue(value); });
3457 };
3458
3459 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3460 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3461 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3462 return emitOpError()
3463 << "expects loop bounds and steps to be defined outside of the "
3464 "taskloop.context region or by pure, regionless operations "
3465 "that do not depend on block arguments";
3466 }
3467
3468 return success();
3469}
3470
3471//===----------------------------------------------------------------------===//
3472// TaskloopWrapperOp
3473//===----------------------------------------------------------------------===//
3474
3475void TaskloopWrapperOp::build(OpBuilder &builder, OperationState &state,
3476 const TaskloopWrapperOperands &clauses) {
3477 TaskloopWrapperOp::build(builder, state);
3478}
3479
3480TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3481 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3482}
3483
3484LogicalResult TaskloopWrapperOp::verify() {
3485 TaskloopContextOp context = getTaskloopContext();
3486 if (!context)
3487 return emitOpError() << "expected to be nested in a taskloop context op";
3488 return success();
3489}
3490
3491LogicalResult TaskloopWrapperOp::verifyRegions() {
3492 if (LoopWrapperInterface nested = getNestedWrapper()) {
3493 if (!isComposite())
3494 return emitError()
3495 << "'omp.composite' attribute missing from composite wrapper";
3496
3497 // Check for the allowed leaf constructs that may appear in a composite
3498 // construct directly after TASKLOOP.
3499 if (!isa<SimdOp>(nested))
3500 return emitError() << "only supported nested wrapper is 'omp.simd'";
3501 } else if (isComposite()) {
3502 return emitError()
3503 << "'omp.composite' attribute present in non-composite wrapper";
3504 }
3505
3506 return success();
3507}
3508
3509//===----------------------------------------------------------------------===//
3510// LoopNestOp
3511//===----------------------------------------------------------------------===//
3512
// Custom parser for LoopNestOp. The accepted form, as read off the calls
// below, is:
//   (ivs) : type = (lbs) to (ubs) [inclusive] step (steps)
//       [collapse(N)] [tiles(t0, t1, ...)] region [attr-dict]
// All induction variables, bounds and steps share the single parsed type.
// Returns failure as soon as any mandatory piece fails to parse.
//
// NOTE(review): this listing omits several original lines (3515-3516, 3518,
// 3536, 3553). They presumably declare `ivs`, `lbs`, `ubs`, `steps` and
// `tiles` and open the first `if (` with the induction-variable list parse;
// confirm against the upstream file before editing this function.
3513ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
3514 // Parse an opening `(` followed by induction variables followed by `)`
3517 Type loopVarType;
3519 parser.parseColonType(loopVarType) ||
3520 // Parse loop bounds.
3521 parser.parseEqual() ||
3522 parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
3523 parser.parseKeyword("to") ||
3524 parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
3525 return failure();
3526
// Every induction variable gets the one type that was parsed after the colon.
3527 for (auto &iv : ivs)
3528 iv.type = loopVarType;
3529
3530 auto *ctx = parser.getBuilder().getContext();
3531 // Parse "inclusive" flag.
3532 if (succeeded(parser.parseOptionalKeyword("inclusive")))
3533 result.addAttribute("loop_inclusive", UnitAttr::get(ctx));
3534
3535 // Parse step values.
3537 if (parser.parseKeyword("step") ||
3538 parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
3539 return failure();
3540
3541 // Parse collapse
// `value` stays 0 when the keyword is absent; the attribute is only attached
// for values greater than 1.
3542 int64_t value = 0;
3543 if (!parser.parseOptionalKeyword("collapse") &&
3544 (parser.parseLParen() || parser.parseInteger(value) ||
3545 parser.parseRParen()))
3546 return failure();
3547 if (value > 1)
3548 result.addAttribute(
3549 "collapse_num_loops",
3550 IntegerAttr::get(parser.getBuilder().getI64Type(), value));
3551
3552 // Parse tiles
// Element callback for the comma-separated list: parses one integer and
// appends it to `tiles`.
3554 auto parseTiles = [&]() -> ParseResult {
3555 int64_t tile;
3556 if (parser.parseInteger(tile))
3557 return failure();
3558 tiles.push_back(tile);
3559 return success();
3560 };
3561
3562 if (!parser.parseOptionalKeyword("tiles") &&
3563 (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
3564 parser.parseRParen()))
3565 return failure();
3566
// The attribute is omitted entirely when no tile size was parsed.
3567 if (tiles.size() > 0)
3568 result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));
3569
3570 // Parse the body.
// The induction variables become the region's entry block arguments.
3571 Region *region = result.addRegion();
3572 if (parser.parseRegion(*region, ivs))
3573 return failure();
3574
3575 // Resolve operands.
// Operands are appended in the order: lower bounds, upper bounds, steps.
3576 if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
3577 parser.resolveOperands(ubs, loopVarType, result.operands) ||
3578 parser.resolveOperands(steps, loopVarType, result.operands))
3579 return failure();
3580
3581 // Parse the optional attribute list.
3582 return parser.parseOptionalAttrDict(result.attributes);
3583}
3584
3585void LoopNestOp::print(OpAsmPrinter &p) {
3586 Region &region = getRegion();
3587 auto args = region.getArguments();
3588 p << " (" << args << ") : " << args[0].getType() << " = ("
3589 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3590 if (getLoopInclusive())
3591 p << "inclusive ";
3592 p << "step (" << getLoopSteps() << ") ";
3593 if (int64_t numCollapse = getCollapseNumLoops())
3594 if (numCollapse > 1)
3595 p << "collapse(" << numCollapse << ") ";
3596
3597 if (const auto tiles = getTileSizes())
3598 p << "tiles(" << tiles.value() << ") ";
3599
3600 p.printRegion(region, /*printEntryBlockArgs=*/false);
3601}
3602
3603void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3604 const LoopNestOperands &clauses) {
3605 MLIRContext *ctx = builder.getContext();
3606 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3607 clauses.loopLowerBounds, clauses.loopUpperBounds,
3608 clauses.loopSteps, clauses.loopInclusive,
3609 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3610}
3611
3612LogicalResult LoopNestOp::verify() {
3613 if (getLoopLowerBounds().empty())
3614 return emitOpError() << "must represent at least one loop";
3615
3616 if (getLoopLowerBounds().size() != getIVs().size())
3617 return emitOpError() << "number of range arguments and IVs do not match";
3618
3619 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3620 if (lb.getType() != iv.getType())
3621 return emitOpError()
3622 << "range argument type does not match corresponding IV type";
3623 }
3624
3625 uint64_t numIVs = getIVs().size();
3626
3627 if (const auto &numCollapse = getCollapseNumLoops())
3628 if (numCollapse > numIVs)
3629 return emitOpError()
3630 << "collapse value is larger than the number of loops";
3631
3632 if (const auto &tiles = getTileSizes())
3633 if (tiles.value().size() > numIVs)
3634 return emitOpError() << "too few canonical loops for tile dimensions";
3635
3636 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3637 return emitOpError() << "expects parent op to be a loop wrapper";
3638
3639 return success();
3640}
3641
// Appends to `wrappers` every consecutive ancestor of this loop nest that
// implements LoopWrapperInterface, starting with the direct parent and moving
// outwards. The walk stops at the first ancestor that is not a loop wrapper
// (or when there is no further parent).
//
// NOTE(review): the parameter line (original 3643) is missing from this
// listing; `wrappers` is presumably an output vector of LoopWrapperInterface
// passed by reference. Confirm against the upstream file.
3642void LoopNestOp::gatherWrappers(
3644 Operation *parent = (*this)->getParentOp();
// `dyn_cast_if_present` tolerates `parent` becoming null at the top level.
3645 while (auto wrapper =
3646 llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
3647 wrappers.push_back(wrapper);
3648 parent = parent->getParentOp();
3649 }
3650}
3651
3652//===----------------------------------------------------------------------===//
3653// OpenMP canonical loop handling
3654//===----------------------------------------------------------------------===//
3655
3656std::tuple<NewCliOp, OpOperand *, OpOperand *>
3657mlir::omp ::decodeCli(Value cli) {
3658
3659 // Defining a CLI for a generated loop is optional; if there is none then
3660 // there is no followup-tranformation
3661 if (!cli)
3662 return {{}, nullptr, nullptr};
3663
3664 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3665 "Unexpected type of cli");
3666
3667 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3668 OpOperand *gen = nullptr;
3669 OpOperand *cons = nullptr;
3670 for (OpOperand &use : cli.getUses()) {
3671 auto op = cast<LoopTransformationInterface>(use.getOwner());
3672
3673 unsigned opnum = use.getOperandNumber();
3674 if (op.isGeneratee(opnum)) {
3675 assert(!gen && "Each CLI may have at most one def");
3676 gen = &use;
3677 } else if (op.isApplyee(opnum)) {
3678 assert(!cons && "Each CLI may have at most one consumer");
3679 cons = &use;
3680 } else {
3681 llvm_unreachable("Unexpected operand for a CLI");
3682 }
3683 }
3684
3685 return {create, gen, cons};
3686}
3687
3688void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3689 ::mlir::OperationState &odsState) {
3690 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3691}
3692
// Suggests a printed SSA name for the CLI result, derived from the operation
// that generates the loop this CLI refers to. Falls back to "cli" when the
// CLI has no generator.
//
// NOTE(review): original line 3711 is missing from this listing. It
// presumably starts the type-switch over the generator's owning operation
// whose `.Case(...)` arms follow; confirm against the upstream file.
3693void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
3694 Value result = getResult();
3695 auto [newCli, gen, cons] = decodeCli(result);
3696
3697 // Structured binding `gen` cannot be captured in lambdas before C++20
3698 OpOperand *generator = gen;
3699
3700 // Derive the CLI variable name from its generator:
3701 // * "canonloop" for omp.canonical_loop
3702 // * custom name for loop transformation generatees
3703 // * "cli" as fallback if no generator
3704 // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
3705 // at that level
3706 // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
3707 // the index of that region
3708 std::string cliName{"cli"};
3709 if (gen) {
3710 cliName =
3712 .Case([&](CanonicalLoopOp op) {
3713 return generateLoopNestingName("canonloop", op);
3714 })
// An UnrollHeuristicOp reports no generatees, so it can never be the
// generator of a CLI.
3715 .Case([&](UnrollHeuristicOp op) -> std::string {
3716 llvm_unreachable("heuristic unrolling does not generate a loop");
3717 })
3718 .Case([&](FuseOp op) -> std::string {
3719 unsigned opnum = generator->getOperandNumber();
3720 // The position of the first loop to be fused is the same position
3721 // as the resulting fused loop
3722 if (op.getFirst().has_value() && opnum != op.getFirst().value())
3723 return "canonloop_fuse";
3724 else
3725 return "fused";
3726 })
// TileOp generatees are split in two halves: the first half names the grid
// loops, the second half the intratile loops.
3727 .Case([&](TileOp op) -> std::string {
3728 auto [generateesFirst, generateesCount] =
3729 op.getGenerateesODSOperandIndexAndLength();
3730 unsigned firstGrid = generateesFirst;
3731 unsigned firstIntratile = generateesFirst + generateesCount / 2;
3732 unsigned end = generateesFirst + generateesCount;
3733 unsigned opnum = generator->getOperandNumber();
3734 // In the OpenMP apply and looprange clauses, indices are 1-based
3735 if (firstGrid <= opnum && opnum < firstIntratile) {
3736 unsigned gridnum = opnum - firstGrid + 1;
3737 return ("grid" + Twine(gridnum)).str();
3738 }
3739 if (firstIntratile <= opnum && opnum < end) {
3740 unsigned intratilenum = opnum - firstIntratile + 1;
3741 return ("intratile" + Twine(intratilenum)).str();
3742 }
3743 llvm_unreachable("Unexpected generatee argument");
3744 })
3745 .DefaultUnreachable("TODO: Custom name for this operation");
3746 }
3747
3748 setNameFn(result, cliName);
3749}
3750
3751LogicalResult NewCliOp::verify() {
3752 Value cli = getResult();
3753
3754 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3755 "Unexpected type of cli");
3756
3757 // Check that the CLI is used in at most generator and one consumer
3758 OpOperand *gen = nullptr;
3759 OpOperand *cons = nullptr;
3760 for (mlir::OpOperand &use : cli.getUses()) {
3761 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
3762
3763 unsigned opnum = use.getOperandNumber();
3764 if (op.isGeneratee(opnum)) {
3765 if (gen) {
3766 InFlightDiagnostic error =
3767 emitOpError("CLI must have at most one generator");
3768 error.attachNote(gen->getOwner()->getLoc())
3769 .append("first generator here:");
3770 error.attachNote(use.getOwner()->getLoc())
3771 .append("second generator here:");
3772 return error;
3773 }
3774
3775 gen = &use;
3776 } else if (op.isApplyee(opnum)) {
3777 if (cons) {
3778 InFlightDiagnostic error =
3779 emitOpError("CLI must have at most one consumer");
3780 error.attachNote(cons->getOwner()->getLoc())
3781 .append("first consumer here:")
3782 .appendOp(*cons->getOwner(),
3783 OpPrintingFlags().printGenericOpForm());
3784 error.attachNote(use.getOwner()->getLoc())
3785 .append("second consumer here:")
3786 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
3787 return error;
3788 }
3789
3790 cons = &use;
3791 } else {
3792 llvm_unreachable("Unexpected operand for a CLI");
3793 }
3794 }
3795
3796 // If the CLI is source of a transformation, it must have a generator
3797 if (cons && !gen) {
3798 InFlightDiagnostic error = emitOpError("CLI has no generator");
3799 error.attachNote(cons->getOwner()->getLoc())
3800 .append("see consumer here: ")
3801 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
3802 return error;
3803 }
3804
3805 return success();
3806}
3807
3808void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3809 Value tripCount) {
3810 odsState.addOperands(tripCount);
3811 odsState.addOperands(Value());
3812 (void)odsState.addRegion();
3813}
3814
3815void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
3816 Value tripCount, ::mlir::Value cli) {
3817 odsState.addOperands(tripCount);
3818 odsState.addOperands(cli);
3819 (void)odsState.addRegion();
3820}
3821
3822void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
3823 setNameFn(&getRegion().front(), "body_entry");
3824}
3825
3826void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
3827 OpAsmSetValueNameFn setNameFn) {
3828 std::string ivName = generateLoopNestingName("iv", *this);
3829 setNameFn(region.getArgument(0), ivName);
3830}
3831
3832void CanonicalLoopOp::print(OpAsmPrinter &p) {
3833 if (getCli())
3834 p << '(' << getCli() << ')';
3835 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
3836 << " in range(" << getTripCount() << ") ";
3837
3838 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
3839 /*printBlockTerminators=*/true);
3840
3841 p.printOptionalAttrDict((*this)->getAttrs());
3842}
3843
// Custom parser for CanonicalLoopOp. The accepted form, as read off the calls
// below, is:
//   [(cli)] iv : type in range(tripcount) region [attr-dict]
// The resulting operand order is trip count first, then the optional CLI.
//
// NOTE(review): original lines 3845, 3850 and 3862 are missing from this
// listing. They presumably hold the `result` parameter and the declarations
// of `cli` and `tripcount`; confirm against the upstream file.
3844mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
3846 CanonicalLoopInfoType cliType =
3847 CanonicalLoopInfoType::get(parser.getContext());
3848
3849 // Parse (optional) omp.cli identifier
// The resolved CLI is parked in `cliOperand` and appended to the operand
// list only after the trip count has been added.
3851 SmallVector<mlir::Value, 1> cliOperand;
3852 if (!parser.parseOptionalLParen()) {
3853 if (parser.parseOperand(cli) ||
3854 parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
3855 return failure();
3856 }
3857
3858 // We derive the type of tripCount from inductionVariable. MLIR requires the
3859 // type of tripCount to be known when calling resolveOperand, so we have to
3860 // parse the type before processing the inductionVariable.
3861 OpAsmParser::Argument inductionVariable;
3863 if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
3864 parser.parseKeyword("in") || parser.parseKeyword("range") ||
3865 parser.parseLParen() || parser.parseOperand(tripcount) ||
3866 parser.parseRParen() ||
3867 parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
3868 return failure();
3869
3870 // Parse the loop body.
3871 Region *region = result.addRegion();
3872 if (parser.parseRegion(*region, {inductionVariable}))
3873 return failure();
3874
3875 // We parsed the cli operand first, but because it is optional, it must be
3876 // last in the operand list.
3877 result.operands.append(cliOperand);
3878
3879 // Parse the optional attribute list.
3880 if (parser.parseOptionalAttrDict(result.attributes))
3881 return failure();
3882
3883 return mlir::success();
3884}
3885
3886LogicalResult CanonicalLoopOp::verify() {
3887 // The region's entry must accept the induction variable
3888 // It can also be empty if just created
3889 if (!getRegion().empty()) {
3890 Region &region = getRegion();
3891 if (region.getNumArguments() != 1)
3892 return emitOpError(
3893 "Canonical loop region must have exactly one argument");
3894
3895 if (getInductionVar().getType() != getTripCount().getType())
3896 return emitOpError(
3897 "Region argument must be the same type as the trip count");
3898 }
3899
3900 return success();
3901}
3902
3903Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
3904
3905std::pair<unsigned, unsigned>
3906CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
3907 // No applyees
3908 return {0, 0};
3909}
3910
3911std::pair<unsigned, unsigned>
3912CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
3913 return getODSOperandIndexAndLength(odsIndex_cli);
3914}
3915
3916//===----------------------------------------------------------------------===//
3917// UnrollHeuristicOp
3918//===----------------------------------------------------------------------===//
3919
3920void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
3921 ::mlir::OperationState &odsState,
3922 ::mlir::Value cli) {
3923 odsState.addOperands(cli);
3924}
3925
3926void UnrollHeuristicOp::print(OpAsmPrinter &p) {
3927 p << '(' << getApplyee() << ')';
3928
3929 p.printOptionalAttrDict((*this)->getAttrs());
3930}
3931
// Custom parser for UnrollHeuristicOp. The accepted form, as read off the
// calls below, is:
//   (applyee) [-> ()] [attr-dict]
// The applyee is resolved with the canonical-loop-info type.
//
// NOTE(review): original lines 3933 and 3939 are missing from this listing.
// They presumably hold the `result` parameter and the declaration of
// `applyee`; confirm against the upstream file.
3932mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
3934 auto cliType = CanonicalLoopInfoType::get(parser.getContext());
3935
3936 if (parser.parseLParen())
3937 return failure();
3938
3940 if (parser.parseOperand(applyee) ||
3941 parser.resolveOperand(applyee, cliType, result.operands))
3942 return failure();
3943
3944 if (parser.parseRParen())
3945 return failure();
3946
3947 // Optional output loop (full unrolling has none)
// Only an empty `()` is accepted after the arrow; nothing is recorded for it.
3948 if (!parser.parseOptionalArrow()) {
3949 if (parser.parseLParen() || parser.parseRParen())
3950 return failure();
3951 }
3952
3953 // Parse the optional attribute list.
3954 if (parser.parseOptionalAttrDict(result.attributes))
3955 return failure();
3956
3957 return mlir::success();
3958}
3959
3960std::pair<unsigned, unsigned>
3961UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
3962 return getODSOperandIndexAndLength(odsIndex_applyee);
3963}
3964
3965std::pair<unsigned, unsigned>
3966UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
3967 return {0, 0};
3968}
3969
3970//===----------------------------------------------------------------------===//
3971// TileOp
3972//===----------------------------------------------------------------------===//
3973
3974static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
3975 OperandRange generatees,
3976 OperandRange applyees) {
3977 if (!generatees.empty())
3978 p << '(' << llvm::interleaved(generatees) << ')';
3979
3980 if (!applyees.empty())
3981 p << " <- (" << llvm::interleaved(applyees) << ')';
3982}
3983
// Parses the CLI lists of a loop transformation in one of two forms:
//   (generatees) <- (applyees)
//   <- (applyees)
// The operands are stored unresolved in `generateesOperands` and
// `applyeesOperands`.
//
// NOTE(review): original lines 3986-3987, 3992 and 4006 are missing from this
// listing. They presumably hold the two output parameters and the trailing
// arguments of the `parseOperandList` calls; confirm against the upstream
// file.
3984static ParseResult parseLoopTransformClis(
3985 OpAsmParser &parser,
// `parseOptionalLess` reports failure when the next token is not `<`, which
// means a generatee list comes first.
3988 if (parser.parseOptionalLess()) {
3989 // Syntax 1: generatees present
3990
3991 if (parser.parseOperandList(generateesOperands,
3993 return failure();
3994
3995 if (parser.parseLess())
3996 return failure();
3997 } else {
3998 // Syntax 2: generatees omitted
3999 }
4000
4001 // Parse `<-` (`<` has already been parsed)
4002 if (parser.parseMinus())
4003 return failure();
4004
4005 if (parser.parseOperandList(applyeesOperands,
4007 return failure();
4008
4009 return success();
4010}
4011
4012/// Check properties of the loop nest consisting of the transformation's
4013/// applyees:
4014/// 1. They are nested inside each other
4015/// 2. They are perfectly nested
4016/// (no code with side-effects in-between the loops)
4017/// 3. They are rectangular
4018/// (loop bounds are invariant in respect to the outer loops)
4019///
4020/// TODO: Generalize for LoopTransformationInterface.
//
// Returns failure (after emitting an error on `op`) when an applyee CLI has
// no generator or one of the three properties is violated. The three checks
// are only performed when every applyee is generated by an
// omp.canonical_loop; otherwise the function succeeds without checking them.
//
// NOTE(review): original line 4024 is missing from this listing. It
// presumably declares the `canonLoops` vector of CanonicalLoopOp; confirm
// against the upstream file.
4021static LogicalResult checkApplyeesNesting(TileOp op) {
4022 // Collect the loops from the nest
4023 bool isOnlyCanonLoops = true;
4025 for (Value applyee : op.getApplyees()) {
4026 auto [create, gen, cons] = decodeCli(applyee);
4027
4028 if (!gen)
4029 return op.emitOpError() << "applyee CLI has no generator";
4030
// A null entry is pushed for generators that are not canonical loops; the
// flag below records that at least one such entry exists.
4031 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4032 canonLoops.push_back(loop);
4033 if (!loop)
4034 isOnlyCanonLoops = false;
4035 }
4036
4037 // FIXME: We currently can only verify non-rectangularity and perfect nest of
4038 // omp.canonical_loop.
4039 if (!isOnlyCanonLoops)
4040 return success();
4041
// Induction variables of all loops enclosing the one currently inspected.
4042 DenseSet<Value> parentIVs;
4043 for (auto i : llvm::seq<int>(1, canonLoops.size())) {
4044 auto parentLoop = canonLoops[i - 1];
4045 auto loop = canonLoops[i];
4046
// Property 1: each loop must be a direct child of the previous applyee.
4047 if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
4048 return op.emitOpError()
4049 << "tiled loop nest must be nested within each other";
4050
4051 parentIVs.insert(parentLoop.getInductionVar());
4052
4053 // Canonical loop must be perfectly nested, i.e. the body of the parent must
4054 // only contain the omp.canonical_loop of the nested loops, and
4055 // omp.terminator
4056 bool isPerfectlyNested = [&]() {
4057 auto &parentBody = parentLoop.getRegion();
4058 if (!parentBody.hasOneBlock())
4059 return false;
4060 auto &parentBlock = parentBody.getBlocks().front();
4061
// The first operation of the parent block must be the nested loop itself.
4062 auto nestedLoopIt = parentBlock.begin();
4063 if (nestedLoopIt == parentBlock.end() ||
4064 (&*nestedLoopIt != loop.getOperation()))
4065 return false;
4066
// The second operation must be the terminator...
4067 auto termIt = std::next(nestedLoopIt);
4068 if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
4069 return false;
4070
// ...and nothing may follow it.
4071 if (std::next(termIt) != parentBlock.end())
4072 return false;
4073
4074 return true;
4075 }();
4076 if (!isPerfectlyNested)
4077 return op.emitOpError() << "tiled loop nest must be perfectly nested";
4078
// Property 3: the trip count must not be the induction variable of an
// enclosing loop (only direct use of the IV value is detected here).
4079 if (parentIVs.contains(loop.getTripCount()))
4080 return op.emitOpError() << "tiled loop nest must be rectangular";
4081 }
4082
4083 // TODO: The tile sizes must be computed before the loop, but checking this
4084 // requires dominance analysis. For instance:
4085 //
4086 // %canonloop = omp.new_cli
4087 // omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
4088 // // write to %x
4089 // omp.terminator
4090 // }
4091 // %ts = llvm.load %x
4092 // omp.tile <- (%canonloop) sizes(%ts : i32)
4093
4094 return success();
4095}
4096
4097LogicalResult TileOp::verify() {
4098 if (getApplyees().empty())
4099 return emitOpError() << "must apply to at least one loop";
4100
4101 if (getSizes().size() != getApplyees().size())
4102 return emitOpError() << "there must be one tile size for each applyee";
4103
4104 if (!getGeneratees().empty() &&
4105 2 * getSizes().size() != getGeneratees().size())
4106 return emitOpError()
4107 << "expecting two times the number of generatees than applyees";
4108
4109 return checkApplyeesNesting(*this);
4110}
4111
4112std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4113 return getODSOperandIndexAndLength(odsIndex_applyees);
4114}
4115
4116std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4117 return getODSOperandIndexAndLength(odsIndex_generatees);
4118}
4119
4120//===----------------------------------------------------------------------===//
4121// FuseOp
4122//===----------------------------------------------------------------------===//
4123
4124static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4125 OperandRange generatees,
4126 OperandRange applyees) {
4127 if (!generatees.empty())
4128 p << '(' << llvm::interleaved(generatees) << ')';
4129
4130 if (!applyees.empty())
4131 p << " <- (" << llvm::interleaved(applyees) << ')';
4132}
4133
4134LogicalResult FuseOp::verify() {
4135 if (getApplyees().size() < 2)
4136 return emitOpError() << "must apply to at least two loops";
4137
4138 if (getFirst().has_value() && getCount().has_value()) {
4139 int64_t first = getFirst().value();
4140 int64_t count = getCount().value();
4141 if ((unsigned)(first + count - 1) > getApplyees().size())
4142 return emitOpError() << "the numbers of applyees must be at least first "
4143 "minus one plus count attributes";
4144 if (!getGeneratees().empty() &&
4145 getGeneratees().size() != getApplyees().size() + 1 - count)
4146 return emitOpError() << "the number of generatees must be the number of "
4147 "aplyees plus one minus count";
4148
4149 } else {
4150 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4151 return emitOpError()
4152 << "in a complete fuse the number of generatees must be exactly 1";
4153 }
4154 for (auto &&applyee : getApplyees()) {
4155 auto [create, gen, cons] = decodeCli(applyee);
4156
4157 if (!gen)
4158 return emitOpError() << "applyee CLI has no generator";
4159 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4160 if (!loop)
4161 return emitOpError()
4162 << "currently only supports omp.canonical_loop as applyee";
4163 }
4164 return success();
4165}
4166std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4167 return getODSOperandIndexAndLength(odsIndex_applyees);
4168}
4169
4170std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4171 return getODSOperandIndexAndLength(odsIndex_generatees);
4172}
4173
4174//===----------------------------------------------------------------------===//
4175// Critical construct (2.17.1)
4176//===----------------------------------------------------------------------===//
4177
4178void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4179 const CriticalDeclareOperands &clauses) {
4180 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4181}
4182
4183LogicalResult CriticalDeclareOp::verify() {
4184 return verifySynchronizationHint(*this, getHint());
4185}
4186
4187LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4188 if (getNameAttr()) {
4189 SymbolRefAttr symbolRef = getNameAttr();
4190 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4191 *this, symbolRef);
4192 if (!decl) {
4193 return emitOpError() << "expected symbol reference " << symbolRef
4194 << " to point to a critical declaration";
4195 }
4196 }
4197
4198 return success();
4199}
4200
4201//===----------------------------------------------------------------------===//
4202// Ordered construct
4203//===----------------------------------------------------------------------===//
4204
4205static LogicalResult verifyOrderedParent(Operation &op) {
4206 bool hasRegion = op.getNumRegions() > 0;
4207 auto loopOp = op.getParentOfType<LoopNestOp>();
4208 if (!loopOp) {
4209 if (hasRegion)
4210 return success();
4211
4212 // TODO: Consider if this needs to be the case only for the standalone
4213 // variant of the ordered construct.
4214 return op.emitOpError() << "must be nested inside of a loop";
4215 }
4216
4217 Operation *wrapper = loopOp->getParentOp();
4218 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4219 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4220 if (!orderedAttr)
4221 return op.emitOpError() << "the enclosing worksharing-loop region must "
4222 "have an ordered clause";
4223
4224 if (hasRegion && orderedAttr.getInt() != 0)
4225 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4226 "have a parameter present";
4227
4228 if (!hasRegion && orderedAttr.getInt() == 0)
4229 return op.emitOpError() << "the enclosing loop's ordered clause must "
4230 "have a parameter present";
4231 } else if (!isa<SimdOp>(wrapper)) {
4232 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4233 "or worksharing simd loop";
4234 }
4235 return success();
4236}
4237
4238void OrderedOp::build(OpBuilder &builder, OperationState &state,
4239 const OrderedOperands &clauses) {
4240 OrderedOp::build(builder, state, clauses.doacrossDependType,
4241 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4242}
4243
4244LogicalResult OrderedOp::verify() {
4245 if (failed(verifyOrderedParent(**this)))
4246 return failure();
4247
4248 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4249 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4250 return emitOpError() << "number of variables in depend clause does not "
4251 << "match number of iteration variables in the "
4252 << "doacross loop";
4253
4254 return success();
4255}
4256
4257void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4258 const OrderedRegionOperands &clauses) {
4259 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4260}
4261
4262LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4263
4264//===----------------------------------------------------------------------===//
4265// TaskwaitOp
4266//===----------------------------------------------------------------------===//
4267
4268void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4269 const TaskwaitOperands &clauses) {
4270 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4271 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4272 /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
4273 /*depend_iterated=*/{}, /*nowait=*/nullptr);
4274}
4275
4276//===----------------------------------------------------------------------===//
4277// Verifier for AtomicReadOp
4278//===----------------------------------------------------------------------===//
4279
4280LogicalResult AtomicReadOp::verify() {
4281 if (verifyCommon().failed())
4282 return mlir::failure();
4283
4284 if (auto mo = getMemoryOrder()) {
4285 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4286 *mo == ClauseMemoryOrderKind::Release) {
4287 return emitError(
4288 "memory-order must not be acq_rel or release for atomic reads");
4289 }
4290 }
4291 return verifySynchronizationHint(*this, getHint());
4292}
4293
4294//===----------------------------------------------------------------------===//
4295// Verifier for AtomicWriteOp
4296//===----------------------------------------------------------------------===//
4297
4298LogicalResult AtomicWriteOp::verify() {
4299 if (verifyCommon().failed())
4300 return mlir::failure();
4301
4302 if (auto mo = getMemoryOrder()) {
4303 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4304 *mo == ClauseMemoryOrderKind::Acquire) {
4305 return emitError(
4306 "memory-order must not be acq_rel or acquire for atomic writes");
4307 }
4308 }
4309 return verifySynchronizationHint(*this, getHint());
4310}
4311
4312//===----------------------------------------------------------------------===//
4313// Verifier for AtomicUpdateOp
4314//===----------------------------------------------------------------------===//
4315
4316LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4317 PatternRewriter &rewriter) {
4318 if (op.isNoOp()) {
4319 rewriter.eraseOp(op);
4320 return success();
4321 }
4322 if (Value writeVal = op.getWriteOpVal()) {
4323 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4324 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4325 return success();
4326 }
4327 return failure();
4328}
4329
4330LogicalResult AtomicUpdateOp::verify() {
4331 if (verifyCommon().failed())
4332 return mlir::failure();
4333
4334 if (auto mo = getMemoryOrder()) {
4335 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4336 *mo == ClauseMemoryOrderKind::Acquire) {
4337 return emitError(
4338 "memory-order must not be acq_rel or acquire for atomic updates");
4339 }
4340 }
4341
4342 return verifySynchronizationHint(*this, getHint());
4343}
4344
4345LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4346
4347//===----------------------------------------------------------------------===//
4348// Verifier for AtomicCaptureOp
4349//===----------------------------------------------------------------------===//
4350
4351AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4352 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4353 return op;
4354 return dyn_cast<AtomicReadOp>(getSecondOp());
4355}
4356
4357AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4358 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4359 return op;
4360 return dyn_cast<AtomicWriteOp>(getSecondOp());
4361}
4362
4363AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4364 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4365 return op;
4366 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4367}
4368
4369LogicalResult AtomicCaptureOp::verify() {
4370 return verifySynchronizationHint(*this, getHint());
4371}
4372
4373LogicalResult AtomicCaptureOp::verifyRegions() {
4374 if (verifyRegionsCommon().failed())
4375 return mlir::failure();
4376
4377 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4378 return emitOpError(
4379 "operations inside capture region must not have hint clause");
4380
4381 if (getFirstOp()->getAttr("memory_order") ||
4382 getSecondOp()->getAttr("memory_order"))
4383 return emitOpError(
4384 "operations inside capture region must not have memory_order clause");
4385 return success();
4386}
4387
4388//===----------------------------------------------------------------------===//
4389// CancelOp
4390//===----------------------------------------------------------------------===//
4391
4392void CancelOp::build(OpBuilder &builder, OperationState &state,
4393 const CancelOperands &clauses) {
4394 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4395}
4396
// Walks up the chain of parent operations of `thisOp` and returns the first
// ancestor that belongs to the same dialect as `thisOp`, or nullptr when no
// ancestor does.
//
// NOTE(review): the signature line (original 4397) is missing from this
// listing. Judging by the call in CancelOp::verify, this is presumably
// `getParentInSameDialect(Operation *thisOp)` returning `Operation *`;
// confirm against the upstream file.
4398 Operation *parent = thisOp->getParentOp();
4399 while (parent) {
4400 if (parent->getDialect() == thisOp->getDialect())
4401 return parent;
4402 parent = parent->getParentOp();
4403 }
4404 return nullptr;
4405}
4406
4407LogicalResult CancelOp::verify() {
4408 ClauseCancellationConstructType cct = getCancelDirective();
4409 // The next OpenMP operation in the chain of parents
4410 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4411 if (!structuralParent)
4412 return emitOpError() << "Orphaned cancel construct";
4413
4414 if ((cct == ClauseCancellationConstructType::Parallel) &&
4415 !mlir::isa<ParallelOp>(structuralParent)) {
4416 return emitOpError() << "cancel parallel must appear "
4417 << "inside a parallel region";
4418 }
4419 if (cct == ClauseCancellationConstructType::Loop) {
4420 // structural parent will be omp.loop_nest, directly nested inside
4421 // omp.wsloop
4422 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4423
4424 if (!wsloopOp) {
4425 return emitOpError()
4426 << "cancel loop must appear inside a worksharing-loop region";
4427 }
4428 if (wsloopOp.getNowaitAttr()) {
4429 return emitError() << "A worksharing construct that is canceled "
4430 << "must not have a nowait clause";
4431 }
4432 if (wsloopOp.getOrderedAttr()) {
4433 return emitError() << "A worksharing construct that is canceled "
4434 << "must not have an ordered clause";
4435 }
4436
4437 } else if (cct == ClauseCancellationConstructType::Sections) {
4438 // structural parent will be an omp.section, directly nested inside
4439 // omp.sections
4440 auto sectionsOp =
4441 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4442 if (!sectionsOp) {
4443 return emitOpError() << "cancel sections must appear "
4444 << "inside a sections region";
4445 }
4446 if (sectionsOp.getNowait()) {
4447 return emitError() << "A sections construct that is canceled "
4448 << "must not have a nowait clause";
4449 }
4450 }
4451 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4452 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4453 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4454 return emitOpError() << "cancel taskgroup must appear "
4455 << "inside a task region";
4456 }
4457 return success();
4458}
4459
4460//===----------------------------------------------------------------------===//
4461// CancellationPointOp
4462//===----------------------------------------------------------------------===//
4463
4464void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4465 const CancellationPointOperands &clauses) {
4466 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4467}
4468
4469LogicalResult CancellationPointOp::verify() {
4470 ClauseCancellationConstructType cct = getCancelDirective();
4471 // The next OpenMP operation in the chain of parents
4472 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4473 if (!structuralParent)
4474 return emitOpError() << "Orphaned cancellation point";
4475
4476 if ((cct == ClauseCancellationConstructType::Parallel) &&
4477 !mlir::isa<ParallelOp>(structuralParent)) {
4478 return emitOpError() << "cancellation point parallel must appear "
4479 << "inside a parallel region";
4480 }
4481 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4482 // find the wsloop
4483 if ((cct == ClauseCancellationConstructType::Loop) &&
4484 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4485 return emitOpError() << "cancellation point loop must appear "
4486 << "inside a worksharing-loop region";
4487 }
4488 if ((cct == ClauseCancellationConstructType::Sections) &&
4489 !mlir::isa<omp::SectionOp>(structuralParent)) {
4490 return emitOpError() << "cancellation point sections must appear "
4491 << "inside a sections region";
4492 }
4493 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4494 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4495 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4496 return emitOpError() << "cancellation point taskgroup must appear "
4497 << "inside a task region";
4498 }
4499 return success();
4500}
4501
4502//===----------------------------------------------------------------------===//
4503// MapBoundsOp
4504//===----------------------------------------------------------------------===//
4505
4506LogicalResult MapBoundsOp::verify() {
4507 auto extent = getExtent();
4508 auto upperbound = getUpperBound();
4509 if (!extent && !upperbound)
4510 return emitError("expected extent or upperbound.");
4511 return success();
4512}
4513
4514void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4515 TypeRange /*result_types*/, StringAttr symName,
4516 TypeAttr type) {
4517 PrivateClauseOp::build(
4518 odsBuilder, odsState, symName, type,
4519 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4520 DataSharingClauseType::Private));
4521}
4522
4523LogicalResult PrivateClauseOp::verifyRegions() {
4524 Type argType = getArgType();
4525 auto verifyTerminator = [&](Operation *terminator,
4526 bool yieldsValue) -> LogicalResult {
4527 if (!terminator->getBlock()->getSuccessors().empty())
4528 return success();
4529
4530 if (!llvm::isa<YieldOp>(terminator))
4531 return mlir::emitError(terminator->getLoc())
4532 << "expected exit block terminator to be an `omp.yield` op.";
4533
4534 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4535 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4536
4537 if (!yieldsValue) {
4538 if (yieldedTypes.empty())
4539 return success();
4540
4541 return mlir::emitError(terminator->getLoc())
4542 << "Did not expect any values to be yielded.";
4543 }
4544
4545 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4546 return success();
4547
4548 auto error = mlir::emitError(yieldOp.getLoc())
4549 << "Invalid yielded value. Expected type: " << argType
4550 << ", got: ";
4551
4552 if (yieldedTypes.empty())
4553 error << "None";
4554 else
4555 error << yieldedTypes;
4556
4557 return error;
4558 };
4559
4560 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4561 StringRef regionName,
4562 bool yieldsValue) -> LogicalResult {
4563 assert(!region.empty());
4564
4565 if (region.getNumArguments() != expectedNumArgs)
4566 return mlir::emitError(region.getLoc())
4567 << "`" << regionName << "`: "
4568 << "expected " << expectedNumArgs
4569 << " region arguments, got: " << region.getNumArguments();
4570
4571 for (Block &block : region) {
4572 // MLIR will verify the absence of the terminator for us.
4573 if (!block.mightHaveTerminator())
4574 continue;
4575
4576 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4577 return failure();
4578 }
4579
4580 return success();
4581 };
4582
4583 // Ensure all of the region arguments have the same type
4584 for (Region *region : getRegions())
4585 for (Type ty : region->getArgumentTypes())
4586 if (ty != argType)
4587 return emitError() << "Region argument type mismatch: got " << ty
4588 << " expected " << argType << ".";
4589
4590 mlir::Region &initRegion = getInitRegion();
4591 if (!initRegion.empty() &&
4592 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4593 /*yieldsValue=*/true)))
4594 return failure();
4595
4596 DataSharingClauseType dsType = getDataSharingType();
4597
4598 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4599 return emitError("`private` clauses do not require a `copy` region.");
4600
4601 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4602 return emitError(
4603 "`firstprivate` clauses require at least a `copy` region.");
4604
4605 if (dsType == DataSharingClauseType::FirstPrivate &&
4606 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4607 /*yieldsValue=*/true)))
4608 return failure();
4609
4610 if (!getDeallocRegion().empty() &&
4611 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4612 /*yieldsValue=*/false)))
4613 return failure();
4614
4615 return success();
4616}
4617
4618//===----------------------------------------------------------------------===//
4619// Spec 5.2: Masked construct (10.5)
4620//===----------------------------------------------------------------------===//
4621
4622void MaskedOp::build(OpBuilder &builder, OperationState &state,
4623 const MaskedOperands &clauses) {
4624 MaskedOp::build(builder, state, clauses.filteredThreadId);
4625}
4626
4627//===----------------------------------------------------------------------===//
4628// Spec 5.2: Scan construct (5.6)
4629//===----------------------------------------------------------------------===//
4630
4631void ScanOp::build(OpBuilder &builder, OperationState &state,
4632 const ScanOperands &clauses) {
4633 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4634}
4635
4636LogicalResult ScanOp::verify() {
4637 if (hasExclusiveVars() == hasInclusiveVars())
4638 return emitError(
4639 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4640 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4641 if (parentWsLoopOp.getReductionModAttr() &&
4642 parentWsLoopOp.getReductionModAttr().getValue() ==
4643 ReductionModifier::inscan)
4644 return success();
4645 }
4646 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4647 if (parentSimdOp.getReductionModAttr() &&
4648 parentSimdOp.getReductionModAttr().getValue() ==
4649 ReductionModifier::inscan)
4650 return success();
4651 }
4652 return emitError("SCAN directive needs to be enclosed within a parent "
4653 "worksharing loop construct or SIMD construct with INSCAN "
4654 "reduction modifier");
4655}
4656
4657/// Verifies align clause in allocate directive
4658
4659LogicalResult AllocateDirOp::verify() {
4660 std::optional<uint64_t> align = this->getAlign();
4661
4662 if (align.has_value()) {
4663 if ((align.value() > 0) && !llvm::has_single_bit(align.value()))
4664 return emitError() << "ALIGN value : " << align.value()
4665 << " must be power of 2";
4666 }
4667
4668 return success();
4669}
4670
4671//===----------------------------------------------------------------------===//
4672// TargetAllocMemOp
4673//===----------------------------------------------------------------------===//
4674
4675mlir::Type omp::TargetAllocMemOp::getAllocatedType() {
4676 return getInTypeAttr().getValue();
4677}
4678
4679/// operation ::= %res = (`omp.target_alloc_mem`) $device : devicetype,
4680/// $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
4681/// attr-dict-without-keyword
4682static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser,
4684 auto &builder = parser.getBuilder();
4685 bool hasOperands = false;
4686 std::int32_t typeparamsSize = 0;
4687
4688 // Parse device number as a new operand
4690 mlir::Type deviceType;
4691 if (parser.parseOperand(deviceOperand) || parser.parseColonType(deviceType))
4692 return mlir::failure();
4693 if (parser.resolveOperand(deviceOperand, deviceType, result.operands))
4694 return mlir::failure();
4695 if (parser.parseComma())
4696 return mlir::failure();
4697
4698 mlir::Type intype;
4699 if (parser.parseType(intype))
4700 return mlir::failure();
4701 result.addAttribute("in_type", mlir::TypeAttr::get(intype));
4704 if (!parser.parseOptionalLParen()) {
4705 // parse the LEN params of the derived type. (<params> : <types>)
4707 parser.parseColonTypeList(typeVec) || parser.parseRParen())
4708 return mlir::failure();
4709 typeparamsSize = operands.size();
4710 hasOperands = true;
4711 }
4712 std::int32_t shapeSize = 0;
4713 if (!parser.parseOptionalComma()) {
4714 // parse size to scale by, vector of n dimensions of type index
4716 return mlir::failure();
4717 shapeSize = operands.size() - typeparamsSize;
4718 auto idxTy = builder.getIndexType();
4719 for (std::int32_t i = typeparamsSize, end = operands.size(); i != end; ++i)
4720 typeVec.push_back(idxTy);
4721 hasOperands = true;
4722 }
4723 if (hasOperands &&
4724 parser.resolveOperands(operands, typeVec, parser.getNameLoc(),
4725 result.operands))
4726 return mlir::failure();
4727
4728 mlir::Type restype = builder.getIntegerType(64);
4729 if (!restype) {
4730 parser.emitError(parser.getNameLoc(), "invalid allocate type: ") << intype;
4731 return mlir::failure();
4732 }
4733 llvm::SmallVector<std::int32_t> segmentSizes{1, typeparamsSize, shapeSize};
4734 result.addAttribute("operandSegmentSizes",
4735 builder.getDenseI32ArrayAttr(segmentSizes));
4736 if (parser.parseOptionalAttrDict(result.attributes) ||
4737 parser.addTypeToList(restype, result.types))
4738 return mlir::failure();
4739 return mlir::success();
4740}
4741
4742mlir::ParseResult omp::TargetAllocMemOp::parse(mlir::OpAsmParser &parser,
4744 return parseTargetAllocMemOp(parser, result);
4745}
4746
4747void omp::TargetAllocMemOp::print(mlir::OpAsmPrinter &p) {
4748 p << " ";
4750 p << " : ";
4751 p << getDevice().getType();
4752 p << ", ";
4753 p << getInType();
4754 if (!getTypeparams().empty()) {
4755 p << '(' << getTypeparams() << " : " << getTypeparams().getTypes() << ')';
4756 }
4757 for (auto sh : getShape()) {
4758 p << ", ";
4759 p.printOperand(sh);
4760 }
4761 p.printOptionalAttrDict((*this)->getAttrs(),
4762 {"in_type", "operandSegmentSizes"});
4763}
4764
4765llvm::LogicalResult omp::TargetAllocMemOp::verify() {
4766 mlir::Type outType = getType();
4767 if (!mlir::dyn_cast<IntegerType>(outType))
4768 return emitOpError("must be a integer type");
4769 return mlir::success();
4770}
4771
4772//===----------------------------------------------------------------------===//
4773// WorkdistributeOp
4774//===----------------------------------------------------------------------===//
4775
4776LogicalResult WorkdistributeOp::verify() {
4777 // Check that region exists and is not empty
4778 Region &region = getRegion();
4779 if (region.empty())
4780 return emitOpError("region cannot be empty");
4781 // Verify single entry point.
4782 Block &entryBlock = region.front();
4783 if (entryBlock.empty())
4784 return emitOpError("region must contain a structured block");
4785 // Verify single exit point.
4786 bool hasTerminator = false;
4787 for (Block &block : region) {
4788 if (isa<TerminatorOp>(block.back())) {
4789 if (hasTerminator) {
4790 return emitOpError("region must have exactly one terminator");
4791 }
4792 hasTerminator = true;
4793 }
4794 }
4795 if (!hasTerminator) {
4796 return emitOpError("region must be terminated with omp.terminator");
4797 }
4798 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
4799 // No implicit barrier at end
4800 if (isa<BarrierOp>(op)) {
4801 return emitOpError(
4802 "explicit barriers are not allowed in workdistribute region");
4803 }
4804 // Check for invalid nested constructs
4805 if (isa<ParallelOp>(op)) {
4806 return emitOpError(
4807 "nested parallel constructs not allowed in workdistribute");
4808 }
4809 if (isa<TeamsOp>(op)) {
4810 return emitOpError(
4811 "nested teams constructs not allowed in workdistribute");
4812 }
4813 return WalkResult::advance();
4814 });
4815 if (walkResult.wasInterrupted())
4816 return failure();
4817
4818 Operation *parentOp = (*this)->getParentOp();
4819 if (!llvm::dyn_cast<TeamsOp>(parentOp))
4820 return emitOpError("workdistribute must be nested under teams");
4821 return success();
4822}
4823
4824//===----------------------------------------------------------------------===//
4825// Declare simd [7.7]
4826//===----------------------------------------------------------------------===//
4827
4828LogicalResult DeclareSimdOp::verify() {
4829 // Must be nested inside a function-like op
4830 auto func =
4831 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
4832 if (!func)
4833 return emitOpError() << "must be nested inside a function";
4834
4835 if (getInbranch() && getNotinbranch())
4836 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
4837
4838 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
4839 /*isDeclareSimd=*/true)))
4840 return failure();
4841
4842 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
4843}
4844
4845void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4846 const DeclareSimdOperands &clauses) {
4847 MLIRContext *ctx = odsBuilder.getContext();
4848 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
4849 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
4850 clauses.linearVars, clauses.linearStepVars,
4851 clauses.linearVarTypes, clauses.linearModifiers,
4852 clauses.notinbranch, clauses.simdlen,
4853 clauses.uniformVars);
4854}
4855
4856//===----------------------------------------------------------------------===//
4857// Parser and printer for Uniform Clause
4858//===----------------------------------------------------------------------===//
4859
4860/// uniform ::= `uniform` `(` uniform-list `)`
4861/// uniform-list := uniform-val (`,` uniform-val)*
4862/// uniform-val := ssa-id `:` type
4863static ParseResult
4866 SmallVectorImpl<Type> &uniformTypes) {
4867 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
4868 if (parser.parseOperand(uniformVars.emplace_back()) ||
4869 parser.parseColonType(uniformTypes.emplace_back()))
4870 return mlir::failure();
4871 return mlir::success();
4872 });
4873}
4874
4875/// Print Uniform Clauses
4877 ValueRange uniformVars, TypeRange uniformTypes) {
4878 for (unsigned i = 0; i < uniformVars.size(); ++i) {
4879 if (i != 0)
4880 p << ", ";
4881 p << uniformVars[i] << " : " << uniformTypes[i];
4882 }
4883}
4884
4885//===----------------------------------------------------------------------===//
4886// Parser and printer for Affinity Clause
4887//===----------------------------------------------------------------------===//
4888
4889static ParseResult parseAffinityClause(
4890 OpAsmParser &parser,
4893 SmallVectorImpl<Type> &iteratedTypes,
4894 SmallVectorImpl<Type> &affinityVarTypes) {
4895 if (failed(parseSplitIteratedList(
4896 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
4897 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
4898 return failure();
4899 return success();
4900}
4901
4903 ValueRange iterated, ValueRange affinityVars,
4904 TypeRange iteratedTypes,
4905 TypeRange affinityVarTypes) {
4906 auto nop = [&](Value, Type) {};
4907 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
4908 affinityVarTypes,
4909 /*plain prefix*/ nop,
4910 /*iterated prefix*/ nop);
4911}
4912
4913//===----------------------------------------------------------------------===//
4914// Parser, printer, and verifier for Iterator modifier
4915//===----------------------------------------------------------------------===//
4916
4917static ParseResult
4922 SmallVectorImpl<Type> &lbTypes,
4923 SmallVectorImpl<Type> &ubTypes,
4924 SmallVectorImpl<Type> &stepTypes) {
4925
4926 llvm::SMLoc ivLoc = parser.getCurrentLocation();
4928
4929 // Parse induction variables: %i : i32, %j : i32
4930 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4931 OpAsmParser::Argument &arg = ivArgs.emplace_back();
4932 if (parser.parseArgument(arg))
4933 return failure();
4934
4935 // Optional type, default to Index if not provided
4936 if (succeeded(parser.parseOptionalColon())) {
4937 if (parser.parseType(arg.type))
4938 return failure();
4939 } else {
4940 arg.type = parser.getBuilder().getIndexType();
4941 }
4942 return success();
4943 }))
4944 return failure();
4945
4946 // ) = (
4947 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
4948 return failure();
4949
4950 // Parse Ranges: (%lb to %ub step %st, ...)
4951 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
4952 OpAsmParser::UnresolvedOperand lb, ub, st;
4953 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
4954 parser.parseOperand(ub) || parser.parseKeyword("step") ||
4955 parser.parseOperand(st))
4956 return failure();
4957
4958 lbs.push_back(lb);
4959 ubs.push_back(ub);
4960 steps.push_back(st);
4961 return success();
4962 }))
4963 return failure();
4964
4965 if (parser.parseRParen())
4966 return failure();
4967
4968 if (ivArgs.size() != lbs.size())
4969 return parser.emitError(ivLoc)
4970 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
4971 << " ranges";
4972
4973 for (auto &arg : ivArgs) {
4974 lbTypes.push_back(arg.type);
4975 ubTypes.push_back(arg.type);
4976 stepTypes.push_back(arg.type);
4977 }
4978
4979 return parser.parseRegion(region, ivArgs);
4980}
4981
4983 ValueRange lbs, ValueRange ubs,
4985 TypeRange) {
4986 Block &entry = region.front();
4987
4988 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
4989 if (i != 0)
4990 p << ", ";
4991 p.printRegionArgument(entry.getArgument(i));
4992 }
4993 p << ") = (";
4994
4995 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
4996 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
4997 if (i)
4998 p << ", ";
4999 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
5000 }
5001 p << ") ";
5002
5003 p.printRegion(region, /*printEntryBlockArgs=*/false,
5004 /*printBlockTerminators=*/true);
5005}
5006
5007LogicalResult IteratorOp::verify() {
5008 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
5009 if (!iteratedTy)
5010 return emitOpError() << "result must be omp.iterated<entry_ty>";
5011
5012 for (auto [lb, ub, step] : llvm::zip_equal(
5013 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5014 if (matchPattern(step, m_Zero()))
5015 return emitOpError() << "loop step must not be zero";
5016
5017 IntegerAttr lbAttr;
5018 IntegerAttr ubAttr;
5019 IntegerAttr stepAttr;
5020 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
5021 !matchPattern(ub, m_Constant(&ubAttr)) ||
5022 !matchPattern(step, m_Constant(&stepAttr)))
5023 continue;
5024
5025 const APInt &lbVal = lbAttr.getValue();
5026 const APInt &ubVal = ubAttr.getValue();
5027 const APInt &stepVal = stepAttr.getValue();
5028 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5029 return emitOpError() << "positive loop step requires lower bound to be "
5030 "less than or equal to upper bound";
5031 if (stepVal.isNegative() && lbVal.slt(ubVal))
5032 return emitOpError() << "negative loop step requires lower bound to be "
5033 "greater than or equal to upper bound";
5034 }
5035
5036 Block &b = getRegion().front();
5037 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
5038
5039 if (!yield)
5040 return emitOpError() << "region must be terminated by omp.yield";
5041
5042 if (yield.getNumOperands() != 1)
5043 return emitOpError()
5044 << "omp.yield in omp.iterator region must yield exactly one value";
5045
5046 mlir::Type yieldedTy = yield.getOperand(0).getType();
5047 mlir::Type elemTy = iteratedTy.getElementType();
5048
5049 if (yieldedTy != elemTy)
5050 return emitOpError() << "omp.iterated element type (" << elemTy
5051 << ") does not match omp.yield operand type ("
5052 << yieldedTy << ")";
5053
5054 return success();
5055}
5056
5057#define GET_ATTRDEF_CLASSES
5058#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5059
5060#define GET_OP_CLASSES
5061#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5062
5063#define GET_TYPEDEF_CLASSES
5064#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1497
static Type getElementType(Type type)
Determine the element type of type.
static ze_device_handle_t getDevice(const uint32_t driverIdx=0, const int32_t devIdx=0)
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static mlir::ParseResult parseTargetAllocMemOp(mlir::OpAsmParser &parser, mlir::OperationState &result)
operation ::= res = (omp.target_alloc_mem) $device : devicetype, $in_type ( ( $typeparams ) )?...
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static ArrayRef< int64_t > getShape(Type type)
Returns the shape of the given type.
Definition Traits.cpp:117
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a '=' token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ')' token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a ':' token if present.
virtual ParseResult parseLSquare()=0
Parse a '[' token.
virtual ParseResult parseRSquare()=0
Parse a ']' token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
ParseResult addTypeToList(Type type, SmallVectorImpl< Type > &result)
Add the specified type to the end of the specified type list and return success.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a '=' token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and return it.
virtual ParseResult parseOptionalComma()=0
Parse a ',' token if present.
virtual ParseResult parseColon()=0
Parse a ':' token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a '(' token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a ',' token.
virtual ParseResult parseOptionalLParen()=0
Parse a '(' token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
BlockArgListType getArguments()
Definition Block.h:97
DenseI32ArrayAttr getDenseI32ArrayAttr(ArrayRef< int32_t > values)
Definition Builders.cpp:167
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
IndexType getIndexType()
Definition Builders.cpp:55
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:140
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:158
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:88
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:238
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:712
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:775
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:231
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:700
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:241
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:252
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:256
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:703
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:404
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:823
user_range getUsers()
Returns a range of all users.
Definition Operation.h:899
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:248
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:234
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., it is speculatable and does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling of imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.