MLIR 23.0.0git
OpenMPDialect.cpp
Go to the documentation of this file.
1//===- OpenMPDialect.cpp - MLIR Dialect for OpenMP implementation ---------===//
2//
3// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6//
7//===----------------------------------------------------------------------===//
8//
9// This file implements the OpenMP dialect and its operations.
10//
11//===----------------------------------------------------------------------===//
12
18#include "mlir/IR/Attributes.h"
21#include "mlir/IR/Matchers.h"
24#include "mlir/IR/SymbolTable.h"
27
28#include "llvm/ADT/ArrayRef.h"
29#include "llvm/ADT/PostOrderIterator.h"
30#include "llvm/ADT/STLExtras.h"
31#include "llvm/ADT/STLForwardCompat.h"
32#include "llvm/ADT/SmallString.h"
33#include "llvm/ADT/StringExtras.h"
34#include "llvm/ADT/StringRef.h"
35#include "llvm/ADT/TypeSwitch.h"
36#include "llvm/ADT/bit.h"
37#include "llvm/Support/InterleavedRange.h"
38#include <cstddef>
39#include <iterator>
40#include <optional>
41#include <variant>
42
43#include "mlir/Dialect/OpenMP/OpenMPOpsDialect.cpp.inc"
44#include "mlir/Dialect/OpenMP/OpenMPOpsEnums.cpp.inc"
45#include "mlir/Dialect/OpenMP/OpenMPOpsInterfaces.cpp.inc"
46#include "mlir/Dialect/OpenMP/OpenMPTypeInterfaces.cpp.inc"
47
48using namespace mlir;
49using namespace mlir::omp;
50
53 return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
54}
55
58 return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
59}
60
63 return intArray.empty() ? nullptr : DenseI64ArrayAttr::get(ctx, intArray);
64}
65
66namespace {
67struct MemRefPointerLikeModel
68 : public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
69 MemRefType> {
70 Type getElementType(Type pointer) const {
71 return llvm::cast<MemRefType>(pointer).getElementType();
72 }
73};
74
75struct LLVMPointerPointerLikeModel
76 : public PointerLikeType::ExternalModel<LLVMPointerPointerLikeModel,
77 LLVM::LLVMPointerType> {
78 Type getElementType(Type pointer) const { return Type(); }
79};
80} // namespace
81
82/// Generate a name of a canonical loop nest of the format
83/// `<prefix>(_r<idx>_s<idx>)*`. Hereby, `_r<idx>` identifies the region
84/// argument index of an operation that has multiple regions, if the operation
85/// has multiple regions.
86/// `_s<idx>` identifies the position of an operation within a region, where
87/// only operations that may potentially contain loops ("container operations"
88/// i.e. have region arguments) are counted. Again, it is omitted if there is
89/// only one such operation in a region. If there are canonical loops nested
90/// inside each other, also may also use the format `_d<num>` where <num> is the
91/// nesting depth of the loop.
92///
93/// The generated name is a best-effort to make canonical loop unique within an
94/// SSA namespace. This also means that regions with IsolatedFromAbove property
95/// do not consider any parents or siblings.
96static std::string generateLoopNestingName(StringRef prefix,
97 CanonicalLoopOp op) {
98 struct Component {
99 /// If true, this component describes a region operand of an operation (the
100 /// operand's owner) If false, this component describes an operation located
101 /// in a parent region
102 bool isRegionArgOfOp;
103 bool skip = false;
104 bool isUnique = false;
105
106 size_t idx;
107 Operation *op;
108 Region *parentRegion;
109 size_t loopDepth;
110
111 Operation *&getOwnerOp() {
112 assert(isRegionArgOfOp && "Must describe a region operand");
113 return op;
114 }
115 size_t &getArgIdx() {
116 assert(isRegionArgOfOp && "Must describe a region operand");
117 return idx;
118 }
119
120 Operation *&getContainerOp() {
121 assert(!isRegionArgOfOp && "Must describe a operation of a region");
122 return op;
123 }
124 size_t &getOpPos() {
125 assert(!isRegionArgOfOp && "Must describe a operation of a region");
126 return idx;
127 }
128 bool isLoopOp() const {
129 assert(!isRegionArgOfOp && "Must describe a operation of a region");
130 return isa<CanonicalLoopOp>(op);
131 }
132 Region *&getParentRegion() {
133 assert(!isRegionArgOfOp && "Must describe a operation of a region");
134 return parentRegion;
135 }
136 size_t &getLoopDepth() {
137 assert(!isRegionArgOfOp && "Must describe a operation of a region");
138 return loopDepth;
139 }
140
141 void skipIf(bool v = true) { skip = skip || v; }
142 };
143
144 // List of ancestors, from inner to outer.
145 // Alternates between
146 // * region argument of an operation
147 // * operation within a region
148 SmallVector<Component> components;
149
150 // Gather a list of parent regions and operations, and the position within
151 // their parent
152 Operation *o = op.getOperation();
153 while (o) {
154 // Operation within a region
155 Region *r = o->getParentRegion();
156 if (!r)
157 break;
158
159 llvm::ReversePostOrderTraversal<Block *> traversal(&r->getBlocks().front());
160 size_t idx = 0;
161 bool found = false;
162 size_t sequentialIdx = -1;
163 bool isOnlyContainerOp = true;
164 for (Block *b : traversal) {
165 for (Operation &op : *b) {
166 if (&op == o && !found) {
167 sequentialIdx = idx;
168 found = true;
169 }
170 if (op.getNumRegions()) {
171 idx += 1;
172 if (idx > 1)
173 isOnlyContainerOp = false;
174 }
175 if (found && !isOnlyContainerOp)
176 break;
177 }
178 }
179
180 Component &containerOpInRegion = components.emplace_back();
181 containerOpInRegion.isRegionArgOfOp = false;
182 containerOpInRegion.isUnique = isOnlyContainerOp;
183 containerOpInRegion.getContainerOp() = o;
184 containerOpInRegion.getOpPos() = sequentialIdx;
185 containerOpInRegion.getParentRegion() = r;
186
187 Operation *parent = r->getParentOp();
188
189 // Region argument of an operation
190 Component &regionArgOfOperation = components.emplace_back();
191 regionArgOfOperation.isRegionArgOfOp = true;
192 regionArgOfOperation.isUnique = true;
193 regionArgOfOperation.getArgIdx() = 0;
194 regionArgOfOperation.getOwnerOp() = parent;
195
196 // The IsolatedFromAbove trait of the parent operation implies that each
197 // individual region argument has its own separate namespace, so no
198 // ambiguity.
199 if (!parent || parent->hasTrait<mlir::OpTrait::IsIsolatedFromAbove>())
200 break;
201
202 // Component only needed if operation has multiple region operands. Region
203 // arguments may be optional, but we currently do not consider this.
204 if (parent->getRegions().size() > 1) {
205 auto getRegionIndex = [](Operation *o, Region *r) {
206 for (auto [idx, region] : llvm::enumerate(o->getRegions())) {
207 if (&region == r)
208 return idx;
209 }
210 llvm_unreachable("Region not child of its parent operation");
211 };
212 regionArgOfOperation.isUnique = false;
213 regionArgOfOperation.getArgIdx() = getRegionIndex(parent, r);
214 }
215
216 // next parent
217 o = parent;
218 }
219
220 // Determine whether a region-argument component is not needed
221 for (Component &c : components)
222 c.skipIf(c.isRegionArgOfOp && c.isUnique);
223
224 // Find runs of nested loops and determine each loop's depth in the loop nest
225 size_t numSurroundingLoops = 0;
226 for (Component &c : llvm::reverse(components)) {
227 if (c.skip)
228 continue;
229
230 // non-skipped multi-argument operands interrupt the loop nest
231 if (c.isRegionArgOfOp) {
232 numSurroundingLoops = 0;
233 continue;
234 }
235
236 // Multiple loops in a region means each of them is the outermost loop of a
237 // new loop nest
238 if (!c.isUnique)
239 numSurroundingLoops = 0;
240
241 c.getLoopDepth() = numSurroundingLoops;
242
243 // Next loop is surrounded by one more loop
244 if (isa<CanonicalLoopOp>(c.getContainerOp()))
245 numSurroundingLoops += 1;
246 }
247
248 // In loop nests, skip all but the innermost loop that contains the depth
249 // number
250 bool isLoopNest = false;
251 for (Component &c : components) {
252 if (c.skip || c.isRegionArgOfOp)
253 continue;
254
255 if (!isLoopNest && c.getLoopDepth() >= 1) {
256 // Innermost loop of a loop nest of at least two loops
257 isLoopNest = true;
258 } else if (isLoopNest) {
259 // Non-innermost loop of a loop nest
260 c.skipIf(c.isUnique);
261
262 // If there is no surrounding loop left, this must have been the outermost
263 // loop; leave loop-nest mode for the next iteration
264 if (c.getLoopDepth() == 0)
265 isLoopNest = false;
266 }
267 }
268
269 // Skip non-loop unambiguous regions (but they should interrupt loop nests, so
270 // we mark them as skipped only after computing loop nests)
271 for (Component &c : components)
272 c.skipIf(!c.isRegionArgOfOp && c.isUnique &&
273 !isa<CanonicalLoopOp>(c.getContainerOp()));
274
275 // Components can be skipped if they are already disambiguated by their parent
276 // (or does not have a parent)
277 bool newRegion = true;
278 for (Component &c : llvm::reverse(components)) {
279 c.skipIf(newRegion && c.isUnique);
280
281 // non-skipped components disambiguate unique children
282 if (!c.skip)
283 newRegion = true;
284
285 // ...except canonical loops that need a suffix for each nest
286 if (!c.isRegionArgOfOp && c.getContainerOp())
287 newRegion = false;
288 }
289
290 // Compile the nesting name string
291 SmallString<64> Name{prefix};
292 llvm::raw_svector_ostream NameOS(Name);
293 for (auto &c : llvm::reverse(components)) {
294 if (c.skip)
295 continue;
296
297 if (c.isRegionArgOfOp)
298 NameOS << "_r" << c.getArgIdx();
299 else if (c.getLoopDepth() >= 1)
300 NameOS << "_d" << c.getLoopDepth();
301 else
302 NameOS << "_s" << c.getOpPos();
303 }
304
305 return NameOS.str().str();
306}
307
308void OpenMPDialect::initialize() {
309 addOperations<
310#define GET_OP_LIST
311#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
312 >();
313 addAttributes<
314#define GET_ATTRDEF_LIST
315#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
316 >();
317 addTypes<
318#define GET_TYPEDEF_LIST
319#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
320 >();
321
322 declarePromisedInterface<ConvertToLLVMPatternInterface, OpenMPDialect>();
323
324 MemRefType::attachInterface<MemRefPointerLikeModel>(*getContext());
325 LLVM::LLVMPointerType::attachInterface<LLVMPointerPointerLikeModel>(
326 *getContext());
327
328 // Attach default offload module interface to module op to access
329 // offload functionality through
330 mlir::ModuleOp::attachInterface<mlir::omp::OffloadModuleDefaultModel>(
331 *getContext());
332
333 // Attach default declare target interfaces to operations which can be marked
334 // as declare target (Global Operations and Functions/Subroutines in dialects
335 // that Fortran (or other languages that lower to MLIR) translates too
336 mlir::LLVM::GlobalOp::attachInterface<
338 *getContext());
339 mlir::LLVM::LLVMFuncOp::attachInterface<
341 *getContext());
342 mlir::func::FuncOp::attachInterface<
344}
345
346//===----------------------------------------------------------------------===//
347// Parser and printer for Allocate Clause
348//===----------------------------------------------------------------------===//
349
350/// Parse an allocate clause with allocators and a list of operands with types.
351///
352/// allocate-operand-list :: = allocate-operand |
353/// allocator-operand `,` allocate-operand-list
354/// allocate-operand :: = ssa-id-and-type -> ssa-id-and-type
355/// ssa-id-and-type ::= ssa-id `:` type
356static ParseResult parseAllocateAndAllocator(
357 OpAsmParser &parser,
359 SmallVectorImpl<Type> &allocateTypes,
361 SmallVectorImpl<Type> &allocatorTypes) {
362
363 return parser.parseCommaSeparatedList([&]() {
365 Type type;
366 if (parser.parseOperand(operand) || parser.parseColonType(type))
367 return failure();
368 allocatorVars.push_back(operand);
369 allocatorTypes.push_back(type);
370 if (parser.parseArrow())
371 return failure();
372 if (parser.parseOperand(operand) || parser.parseColonType(type))
373 return failure();
374
375 allocateVars.push_back(operand);
376 allocateTypes.push_back(type);
377 return success();
378 });
379}
380
381/// Print allocate clause
383 OperandRange allocateVars,
384 TypeRange allocateTypes,
385 OperandRange allocatorVars,
386 TypeRange allocatorTypes) {
387 for (unsigned i = 0; i < allocateVars.size(); ++i) {
388 std::string separator = i == allocateVars.size() - 1 ? "" : ", ";
389 p << allocatorVars[i] << " : " << allocatorTypes[i] << " -> ";
390 p << allocateVars[i] << " : " << allocateTypes[i] << separator;
391 }
392}
393
394//===----------------------------------------------------------------------===//
395// Parser and printer for a clause attribute (StringEnumAttr)
396//===----------------------------------------------------------------------===//
397
398template <typename ClauseAttr>
399static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr) {
400 using ClauseT = decltype(std::declval<ClauseAttr>().getValue());
401 StringRef enumStr;
402 SMLoc loc = parser.getCurrentLocation();
403 if (parser.parseKeyword(&enumStr))
404 return failure();
405 if (std::optional<ClauseT> enumValue = symbolizeEnum<ClauseT>(enumStr)) {
406 attr = ClauseAttr::get(parser.getContext(), *enumValue);
407 return success();
408 }
409 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
410}
411
412template <typename ClauseAttr>
413static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr) {
414 p << stringifyEnum(attr.getValue());
415}
416
417//===----------------------------------------------------------------------===//
418// Parser and printer for Linear Clause
419//===----------------------------------------------------------------------===//
420
421/// linear ::= `linear` `(` linear-list `)`
422/// linear-list := linear-val | linear-val linear-list
423/// linear-val := ssa-id-and-type `=` ssa-id-and-type
424/// | `val` `(` ssa-id-and-type `=` ssa-id-and-type `)`
425/// | `ref` `(` ssa-id-and-type `=` ssa-id-and-type `)`
426/// | `uval` `(` ssa-id-and-type `=` ssa-id-and-type `)`
427static ParseResult parseLinearClause(
428 OpAsmParser &parser,
430 SmallVectorImpl<Type> &linearTypes,
432 SmallVectorImpl<Type> &linearStepTypes, ArrayAttr &linearModifiers) {
433 SmallVector<Attribute> modifiers;
434 auto result = parser.parseCommaSeparatedList([&]() {
436 Type type, stepType;
438
439 std::optional<omp::LinearModifier> linearModifier;
440 if (succeeded(parser.parseOptionalKeyword("val"))) {
441 linearModifier = omp::LinearModifier::val;
442 } else if (succeeded(parser.parseOptionalKeyword("ref"))) {
443 linearModifier = omp::LinearModifier::ref;
444 } else if (succeeded(parser.parseOptionalKeyword("uval"))) {
445 linearModifier = omp::LinearModifier::uval;
446 }
447
448 bool hasLinearModifierParens = linearModifier.has_value();
449 if (hasLinearModifierParens && parser.parseLParen())
450 return failure();
451
452 if (parser.parseOperand(var) || parser.parseColonType(type) ||
453 parser.parseEqual() || parser.parseOperand(stepVar) ||
454 parser.parseColonType(stepType))
455 return failure();
456
457 if (hasLinearModifierParens && parser.parseRParen())
458 return failure();
459
460 linearVars.push_back(var);
461 linearTypes.push_back(type);
462 linearStepVars.push_back(stepVar);
463 linearStepTypes.push_back(stepType);
464 if (linearModifier) {
465 modifiers.push_back(
466 omp::LinearModifierAttr::get(parser.getContext(), *linearModifier));
467 } else {
468 modifiers.push_back(UnitAttr::get(parser.getContext()));
469 }
470 return success();
471 });
472 if (failed(result))
473 return failure();
474 linearModifiers = ArrayAttr::get(parser.getContext(), modifiers);
475 return success();
476}
477
478/// Print Linear Clause
480 ValueRange linearVars, TypeRange linearTypes,
481 ValueRange linearStepVars, TypeRange stepVarTypes,
482 ArrayAttr linearModifiers) {
483 size_t linearVarsSize = linearVars.size();
484 for (unsigned i = 0; i < linearVarsSize; ++i) {
485 if (i != 0)
486 p << ", ";
487 // Print modifier keyword wrapper if present.
488 Attribute modAttr = linearModifiers ? linearModifiers[i] : nullptr;
489 auto mod = modAttr ? dyn_cast<omp::LinearModifierAttr>(modAttr) : nullptr;
490 if (mod) {
491 p << omp::stringifyLinearModifier(mod.getValue()) << "(";
492 }
493 p << linearVars[i] << " : " << linearTypes[i];
494 p << " = " << linearStepVars[i] << " : " << stepVarTypes[i];
495 if (mod)
496 p << ")";
497 }
498}
499
500//===----------------------------------------------------------------------===//
501// Verifier for Linear modifier
502//===----------------------------------------------------------------------===//
503
504/// OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or
505/// uval only on a declare simd directive."
506/// Also verifies that modifier count matches variable count.
507static LogicalResult
508verifyLinearModifiers(Operation *op, std::optional<ArrayAttr> linearModifiers,
509 OperandRange linearVars, bool isDeclareSimd = false) {
510 if (!linearModifiers)
511 return success();
512 if (linearModifiers->size() != linearVars.size())
513 return op->emitOpError()
514 << "expected as many linear modifiers as linear variables";
515 if (!isDeclareSimd) {
516 for (Attribute attr : *linearModifiers) {
517 if (!attr)
518 continue;
519 auto modAttr = dyn_cast<omp::LinearModifierAttr>(attr);
520 if (!modAttr)
521 continue;
522 omp::LinearModifier mod = modAttr.getValue();
523 if (mod == omp::LinearModifier::ref || mod == omp::LinearModifier::uval)
524 return op->emitOpError()
525 << "linear modifier '" << omp::stringifyLinearModifier(mod)
526 << "' may only be specified on a declare simd directive";
527 }
528 }
529 return success();
530}
531
532//===----------------------------------------------------------------------===//
533// Verifier for Nontemporal Clause
534//===----------------------------------------------------------------------===//
535
536static LogicalResult verifyNontemporalClause(Operation *op,
537 OperandRange nontemporalVars) {
538
539 // Check if each var is unique - OpenMP 5.0 -> 2.9.3.1 section
540 DenseSet<Value> nontemporalItems;
541 for (const auto &it : nontemporalVars)
542 if (!nontemporalItems.insert(it).second)
543 return op->emitOpError() << "nontemporal variable used more than once";
544
545 return success();
546}
547
548//===----------------------------------------------------------------------===//
549// Parser, verifier and printer for Aligned Clause
550//===----------------------------------------------------------------------===//
551static LogicalResult verifyAlignedClause(Operation *op,
552 std::optional<ArrayAttr> alignments,
553 OperandRange alignedVars) {
554 // Check if number of alignment values equals to number of aligned variables
555 if (!alignedVars.empty()) {
556 if (!alignments || alignments->size() != alignedVars.size())
557 return op->emitOpError()
558 << "expected as many alignment values as aligned variables";
559 } else {
560 if (alignments)
561 return op->emitOpError() << "unexpected alignment values attribute";
562 return success();
563 }
564
565 // Check if each var is aligned only once - OpenMP 4.5 -> 2.8.1 section
566 DenseSet<Value> alignedItems;
567 for (auto it : alignedVars)
568 if (!alignedItems.insert(it).second)
569 return op->emitOpError() << "aligned variable used more than once";
570
571 if (!alignments)
572 return success();
573
574 // Check if all alignment values are positive - OpenMP 4.5 -> 2.8.1 section
575 for (unsigned i = 0; i < (*alignments).size(); ++i) {
576 if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignments)[i])) {
577 if (intAttr.getValue().sle(0))
578 return op->emitOpError() << "alignment should be greater than 0";
579 } else {
580 return op->emitOpError() << "expected integer alignment";
581 }
582 }
583
584 return success();
585}
586
587/// aligned ::= `aligned` `(` aligned-list `)`
588/// aligned-list := aligned-val | aligned-val aligned-list
589/// aligned-val := ssa-id-and-type `->` alignment
590static ParseResult
593 SmallVectorImpl<Type> &alignedTypes,
594 ArrayAttr &alignmentsAttr) {
595 SmallVector<Attribute> alignmentVec;
596 if (failed(parser.parseCommaSeparatedList([&]() {
597 if (parser.parseOperand(alignedVars.emplace_back()) ||
598 parser.parseColonType(alignedTypes.emplace_back()) ||
599 parser.parseArrow() ||
600 parser.parseAttribute(alignmentVec.emplace_back())) {
601 return failure();
602 }
603 return success();
604 })))
605 return failure();
606 SmallVector<Attribute> alignments(alignmentVec.begin(), alignmentVec.end());
607 alignmentsAttr = ArrayAttr::get(parser.getContext(), alignments);
608 return success();
609}
610
611/// Print Aligned Clause
613 ValueRange alignedVars, TypeRange alignedTypes,
614 std::optional<ArrayAttr> alignments) {
615 for (unsigned i = 0; i < alignedVars.size(); ++i) {
616 if (i != 0)
617 p << ", ";
618 p << alignedVars[i] << " : " << alignedVars[i].getType();
619 p << " -> " << (*alignments)[i];
620 }
621}
622
623//===----------------------------------------------------------------------===//
624// Parser, printer and verifier for Schedule Clause
625//===----------------------------------------------------------------------===//
626
627static ParseResult
629 SmallVectorImpl<SmallString<12>> &modifiers) {
630 if (modifiers.size() > 2)
631 return parser.emitError(parser.getNameLoc()) << " unexpected modifier(s)";
632 for (const auto &mod : modifiers) {
633 // Translate the string. If it has no value, then it was not a valid
634 // modifier!
635 auto symbol = symbolizeScheduleModifier(mod);
636 if (!symbol)
637 return parser.emitError(parser.getNameLoc())
638 << " unknown modifier type: " << mod;
639 }
640
641 // If we have one modifier that is "simd", then stick a "none" modiifer in
642 // index 0.
643 if (modifiers.size() == 1) {
644 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd) {
645 modifiers.push_back(modifiers[0]);
646 modifiers[0] = stringifyScheduleModifier(ScheduleModifier::none);
647 }
648 } else if (modifiers.size() == 2) {
649 // If there are two modifier:
650 // First modifier should not be simd, second one should be simd
651 if (symbolizeScheduleModifier(modifiers[0]) == ScheduleModifier::simd ||
652 symbolizeScheduleModifier(modifiers[1]) != ScheduleModifier::simd)
653 return parser.emitError(parser.getNameLoc())
654 << " incorrect modifier order";
655 }
656 return success();
657}
658
659/// schedule ::= `schedule` `(` sched-list `)`
660/// sched-list ::= sched-val | sched-val sched-list |
661/// sched-val `,` sched-modifier
662/// sched-val ::= sched-with-chunk | sched-wo-chunk
663/// sched-with-chunk ::= sched-with-chunk-types (`=` ssa-id-and-type)?
664/// sched-with-chunk-types ::= `static` | `dynamic` | `guided`
665/// sched-wo-chunk ::= `auto` | `runtime`
666/// sched-modifier ::= sched-mod-val | sched-mod-val `,` sched-mod-val
667/// sched-mod-val ::= `monotonic` | `nonmonotonic` | `simd` | `none`
668static ParseResult
669parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr,
670 ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd,
671 std::optional<OpAsmParser::UnresolvedOperand> &chunkSize,
672 Type &chunkType) {
673 StringRef keyword;
674 if (parser.parseKeyword(&keyword))
675 return failure();
676 std::optional<mlir::omp::ClauseScheduleKind> schedule =
677 symbolizeClauseScheduleKind(keyword);
678 if (!schedule)
679 return parser.emitError(parser.getNameLoc()) << " expected schedule kind";
680
681 scheduleAttr = ClauseScheduleKindAttr::get(parser.getContext(), *schedule);
682 switch (*schedule) {
683 case ClauseScheduleKind::Static:
684 case ClauseScheduleKind::Dynamic:
685 case ClauseScheduleKind::Guided:
686 if (succeeded(parser.parseOptionalEqual())) {
687 chunkSize = OpAsmParser::UnresolvedOperand{};
688 if (parser.parseOperand(*chunkSize) || parser.parseColonType(chunkType))
689 return failure();
690 } else {
691 chunkSize = std::nullopt;
692 }
693 break;
694 case ClauseScheduleKind::Auto:
695 case ClauseScheduleKind::Runtime:
696 case ClauseScheduleKind::Distribute:
697 chunkSize = std::nullopt;
698 }
699
700 // If there is a comma, we have one or more modifiers..
702 while (succeeded(parser.parseOptionalComma())) {
703 StringRef mod;
704 if (parser.parseKeyword(&mod))
705 return failure();
706 modifiers.push_back(mod);
707 }
708
709 if (verifyScheduleModifiers(parser, modifiers))
710 return failure();
711
712 if (!modifiers.empty()) {
713 SMLoc loc = parser.getCurrentLocation();
714 if (std::optional<ScheduleModifier> mod =
715 symbolizeScheduleModifier(modifiers[0])) {
716 scheduleMod = ScheduleModifierAttr::get(parser.getContext(), *mod);
717 } else {
718 return parser.emitError(loc, "invalid schedule modifier");
719 }
720 // Only SIMD attribute is allowed here!
721 if (modifiers.size() > 1) {
722 assert(symbolizeScheduleModifier(modifiers[1]) == ScheduleModifier::simd);
723 scheduleSimd = UnitAttr::get(parser.getBuilder().getContext());
724 }
725 }
726
727 return success();
728}
729
730/// Print schedule clause
732 ClauseScheduleKindAttr scheduleKind,
733 ScheduleModifierAttr scheduleMod,
734 UnitAttr scheduleSimd, Value scheduleChunk,
735 Type scheduleChunkType) {
736 p << stringifyClauseScheduleKind(scheduleKind.getValue());
737 if (scheduleChunk)
738 p << " = " << scheduleChunk << " : " << scheduleChunk.getType();
739 if (scheduleMod)
740 p << ", " << stringifyScheduleModifier(scheduleMod.getValue());
741 if (scheduleSimd)
742 p << ", simd";
743}
744
745//===----------------------------------------------------------------------===//
746// Parser and printer for Order Clause
747//===----------------------------------------------------------------------===//
748
749// order ::= `order` `(` [order-modifier ':'] concurrent `)`
750// order-modifier ::= reproducible | unconstrained
751static ParseResult parseOrderClause(OpAsmParser &parser,
752 ClauseOrderKindAttr &order,
753 OrderModifierAttr &orderMod) {
754 StringRef enumStr;
755 SMLoc loc = parser.getCurrentLocation();
756 if (parser.parseKeyword(&enumStr))
757 return failure();
758 if (std::optional<OrderModifier> enumValue =
759 symbolizeOrderModifier(enumStr)) {
760 orderMod = OrderModifierAttr::get(parser.getContext(), *enumValue);
761 if (parser.parseOptionalColon())
762 return failure();
763 loc = parser.getCurrentLocation();
764 if (parser.parseKeyword(&enumStr))
765 return failure();
766 }
767 if (std::optional<ClauseOrderKind> enumValue =
768 symbolizeClauseOrderKind(enumStr)) {
769 order = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
770 return success();
771 }
772 return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
773}
774
776 ClauseOrderKindAttr order,
777 OrderModifierAttr orderMod) {
778 if (orderMod)
779 p << stringifyOrderModifier(orderMod.getValue()) << ":";
780 if (order)
781 p << stringifyClauseOrderKind(order.getValue());
782}
783
784template <typename ClauseTypeAttr, typename ClauseType>
785static ParseResult
786parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness,
787 std::optional<OpAsmParser::UnresolvedOperand> &operand,
788 Type &operandType,
789 std::optional<ClauseType> (*symbolizeClause)(StringRef),
790 StringRef clauseName) {
791 StringRef enumStr;
792 if (succeeded(parser.parseOptionalKeyword(&enumStr))) {
793 if (std::optional<ClauseType> enumValue = symbolizeClause(enumStr)) {
794 prescriptiveness = ClauseTypeAttr::get(parser.getContext(), *enumValue);
795 if (parser.parseComma())
796 return failure();
797 } else {
798 return parser.emitError(parser.getCurrentLocation())
799 << "invalid " << clauseName << " modifier : '" << enumStr << "'";
800 ;
801 }
802 }
803
805 if (succeeded(parser.parseOperand(var))) {
806 operand = var;
807 } else {
808 return parser.emitError(parser.getCurrentLocation())
809 << "expected " << clauseName << " operand";
810 }
811
812 if (operand.has_value()) {
813 if (parser.parseColonType(operandType))
814 return failure();
815 }
816
817 return success();
818}
819
820template <typename ClauseTypeAttr, typename ClauseType>
821static void
823 ClauseTypeAttr prescriptiveness, Value operand,
824 mlir::Type operandType,
825 StringRef (*stringifyClauseType)(ClauseType)) {
826
827 if (prescriptiveness)
828 p << stringifyClauseType(prescriptiveness.getValue()) << ", ";
829
830 if (operand)
831 p << operand << ": " << operandType;
832}
833
834//===----------------------------------------------------------------------===//
835// Parser and printer for grainsize Clause
836//===----------------------------------------------------------------------===//
837
838// grainsize ::= `grainsize` `(` [strict ':'] grain-size `)`
839static ParseResult
840parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod,
841 std::optional<OpAsmParser::UnresolvedOperand> &grainsize,
842 Type &grainsizeType) {
844 parser, grainsizeMod, grainsize, grainsizeType,
845 &symbolizeClauseGrainsizeType, "grainsize");
846}
847
849 ClauseGrainsizeTypeAttr grainsizeMod,
850 Value grainsize, mlir::Type grainsizeType) {
852 p, op, grainsizeMod, grainsize, grainsizeType,
853 &stringifyClauseGrainsizeType);
854}
855
856//===----------------------------------------------------------------------===//
857// Parser and printer for num_tasks Clause
858//===----------------------------------------------------------------------===//
859
860// numtask ::= `num_tasks` `(` [strict ':'] num-tasks `)`
861static ParseResult
862parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod,
863 std::optional<OpAsmParser::UnresolvedOperand> &numTasks,
864 Type &numTasksType) {
866 parser, numTasksMod, numTasks, numTasksType, &symbolizeClauseNumTasksType,
867 "num_tasks");
868}
869
871 ClauseNumTasksTypeAttr numTasksMod,
872 Value numTasks, mlir::Type numTasksType) {
874 p, op, numTasksMod, numTasks, numTasksType, &stringifyClauseNumTasksType);
875}
876
877//===----------------------------------------------------------------------===//
878// Parser and printer for Heap Alloc Clause
879//===----------------------------------------------------------------------===//
880
881/// operation ::= $in_type ( `(` $typeparams `)` )? ( `,` $shape )?
882static ParseResult parseHeapAllocClause(
883 OpAsmParser &parser, TypeAttr &inTypeAttr,
885 SmallVectorImpl<Type> &typeparamsTypes,
887 SmallVectorImpl<Type> &shapeTypes) {
888 mlir::Type inType;
889 if (parser.parseType(inType))
890 return mlir::failure();
891 inTypeAttr = TypeAttr::get(inType);
892
893 if (!parser.parseOptionalLParen()) {
894 // parse the LEN params of the derived type. (<params> : <types>)
895 if (parser.parseOperandList(typeparams, OpAsmParser::Delimiter::None) ||
896 parser.parseColonTypeList(typeparamsTypes) || parser.parseRParen())
897 return failure();
898 }
899
900 if (!parser.parseOptionalComma()) {
901 // parse size to scale by, vector of n dimensions of type index
903 return failure();
904
905 // TODO: This overrides the actual types of the operands, which might cause
906 // issues when they don't match. At the moment this is done in place of
907 // making the corresponding operand type `Variadic<Index>` because index
908 // types are lowered to I64 prior to LLVM IR translation.
909 shapeTypes.append(shape.size(), IndexType::get(parser.getContext()));
910 }
911
912 return success();
913}
914
916 TypeAttr inType, ValueRange typeparams,
917 TypeRange typeparamsTypes, ValueRange shape,
918 TypeRange shapeTypes) {
919 p << inType;
920 if (!typeparams.empty()) {
921 p << '(' << typeparams << " : " << typeparamsTypes << ')';
922 }
923 for (auto sh : shape) {
924 p << ", ";
925 p.printOperand(sh);
926 }
927}
928
929//===----------------------------------------------------------------------===//
930// Parser, printer and verify for dyn_groupprivate Clause
931//===----------------------------------------------------------------------===//
932
933static LogicalResult
934verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup,
935 FallbackModifierAttr fallback,
936 Value dynGroupprivateSize) {
937 if (!dynGroupprivateSize && (accessGroup || fallback))
938 return op->emitOpError("dyn_groupprivate modifiers require a size operand");
939
940 return success();
941}
942
943static ParseResult parseDynGroupprivateClause(
944 OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr,
945 FallbackModifierAttr &fallbackAttr,
946 std::optional<OpAsmParser::UnresolvedOperand> &dynGroupprivateSize,
947 Type &sizeType) {
948
949 bool parsedAccessGroup = false;
950 bool parsedFallback = false;
951 bool parsedSize = false;
952
953 return parser.parseCommaSeparatedList([&]() -> ParseResult {
954 // Parse AccessGroupModifier.
955 if (succeeded(parser.parseOptionalKeyword("cgroup"))) {
956 if (parsedAccessGroup)
957 return parser.emitError(parser.getCurrentLocation(),
958 "duplicate access group modifier");
959 accessGroupAttr = AccessGroupModifierAttr::get(
960 parser.getContext(), AccessGroupModifier::cgroup);
961 parsedAccessGroup = true;
962 return success();
963 }
964 // Parse FallbackModifier.
965 if (succeeded(parser.parseOptionalKeyword("fallback"))) {
966 if (parsedFallback)
967 return parser.emitError(parser.getCurrentLocation(),
968 "duplicate fallback modifier");
969 if (parser.parseLParen())
970 return parser.emitError(parser.getCurrentLocation(),
971 "expected '(' after 'fallback'");
972 llvm::StringRef fbKind;
973 if (parser.parseKeyword(&fbKind))
974 return parser.emitError(
975 parser.getCurrentLocation(),
976 "expected fallback modifier (abort/null/default_mem)");
977 std::optional<FallbackModifier> fbEnum;
978 if (fbKind == "abort")
979 fbEnum = FallbackModifier::abort;
980 else if (fbKind == "null")
981 fbEnum = FallbackModifier::null;
982 else if (fbKind == "default_mem")
983 fbEnum = FallbackModifier::default_mem;
984 else
985 return parser.emitError(parser.getCurrentLocation(),
986 "invalid fallback modifier '" + fbKind + "'");
987 fallbackAttr = FallbackModifierAttr::get(parser.getContext(), *fbEnum);
988 if (parser.parseRParen())
989 return parser.emitError(parser.getCurrentLocation(),
990 "expected ')' after fallback modifier");
991 parsedFallback = true;
992 return success();
993 }
994 // Parse size operand.
996 if (succeeded(parser.parseOperand(operand))) {
997 if (parsedSize)
998 return parser.emitError(parser.getCurrentLocation(),
999 "duplicate size operand");
1000 dynGroupprivateSize = operand;
1001 parsedSize = true;
1002 if (failed(parser.parseColon()) || failed(parser.parseType(sizeType)))
1003 return parser.emitError(parser.getCurrentLocation(),
1004 "expected ':' and type after size operand");
1005 return success();
1006 }
1007 return parser.emitError(parser.getCurrentLocation(),
1008 "expected dyn_groupprivate_size operand");
1009 });
1010}
1011
1013 AccessGroupModifierAttr modifierFirst,
1014 FallbackModifierAttr modifierSecond,
1015 Value dynGroupprivateSize,
1016 Type sizeType) {
1017
1018 bool needsComma = false;
1019
1020 if (modifierFirst) {
1021 printer << modifierFirst.getValue();
1022 needsComma = true;
1023 }
1024
1025 if (modifierSecond) {
1026 if (needsComma)
1027 printer << ", ";
1028 printer << "fallback(";
1029 printer << modifierSecond.getValue();
1030 printer << ")";
1031 needsComma = true;
1032 }
1033
1034 if (dynGroupprivateSize) {
1035 if (needsComma)
1036 printer << ", ";
1037 printer << dynGroupprivateSize << " : " << sizeType;
1038 }
1039}
1040
1041//===----------------------------------------------------------------------===//
1042// Parsers for operations including clauses that define entry block arguments.
1043//===----------------------------------------------------------------------===//
1044
1045namespace {
1046struct MapParseArgs {
1047 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1048 SmallVectorImpl<Type> &types;
1049 MapParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1050 SmallVectorImpl<Type> &types)
1051 : vars(vars), types(types) {}
1052};
1053struct PrivateParseArgs {
1054 llvm::SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1055 llvm::SmallVectorImpl<Type> &types;
1056 ArrayAttr &syms;
1057 UnitAttr &needsBarrier;
1058 DenseI64ArrayAttr *mapIndices;
1059 PrivateParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1060 SmallVectorImpl<Type> &types, ArrayAttr &syms,
1061 UnitAttr &needsBarrier,
1062 DenseI64ArrayAttr *mapIndices = nullptr)
1063 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1064 mapIndices(mapIndices) {}
1065};
1066
1067struct ReductionParseArgs {
1068 SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars;
1069 SmallVectorImpl<Type> &types;
1070 DenseBoolArrayAttr &byref;
1071 ArrayAttr &syms;
1072 ReductionModifierAttr *modifier;
1073 ReductionParseArgs(SmallVectorImpl<OpAsmParser::UnresolvedOperand> &vars,
1074 SmallVectorImpl<Type> &types, DenseBoolArrayAttr &byref,
1075 ArrayAttr &syms, ReductionModifierAttr *mod = nullptr)
1076 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1077};
1078
1079struct AllRegionParseArgs {
1080 std::optional<MapParseArgs> hasDeviceAddrArgs;
1081 std::optional<MapParseArgs> hostEvalArgs;
1082 std::optional<ReductionParseArgs> inReductionArgs;
1083 std::optional<MapParseArgs> mapArgs;
1084 std::optional<PrivateParseArgs> privateArgs;
1085 std::optional<ReductionParseArgs> reductionArgs;
1086 std::optional<ReductionParseArgs> taskReductionArgs;
1087 std::optional<MapParseArgs> useDeviceAddrArgs;
1088 std::optional<MapParseArgs> useDevicePtrArgs;
1089};
1090} // namespace
1091
1092static inline constexpr StringRef getPrivateNeedsBarrierSpelling() {
1093 return "private_barrier";
1094}
1095
1096static ParseResult parseClauseWithRegionArgs(
1097 OpAsmParser &parser,
1099 SmallVectorImpl<Type> &types,
1100 SmallVectorImpl<OpAsmParser::Argument> &regionPrivateArgs,
1101 ArrayAttr *symbols = nullptr, DenseI64ArrayAttr *mapIndices = nullptr,
1102 DenseBoolArrayAttr *byref = nullptr,
1103 ReductionModifierAttr *modifier = nullptr,
1104 UnitAttr *needsBarrier = nullptr) {
1106 SmallVector<int64_t> mapIndicesVec;
1107 SmallVector<bool> isByRefVec;
1108 unsigned regionArgOffset = regionPrivateArgs.size();
1109
1110 if (parser.parseLParen())
1111 return failure();
1112
1113 if (modifier && succeeded(parser.parseOptionalKeyword("mod"))) {
1114 StringRef enumStr;
1115 if (parser.parseColon() || parser.parseKeyword(&enumStr) ||
1116 parser.parseComma())
1117 return failure();
1118 std::optional<ReductionModifier> enumValue =
1119 symbolizeReductionModifier(enumStr);
1120 if (!enumValue.has_value())
1121 return failure();
1122 *modifier = ReductionModifierAttr::get(parser.getContext(), *enumValue);
1123 if (!*modifier)
1124 return failure();
1125 }
1126
1127 if (parser.parseCommaSeparatedList([&]() {
1128 if (byref)
1129 isByRefVec.push_back(
1130 parser.parseOptionalKeyword("byref").succeeded());
1131
1132 if (symbols && parser.parseAttribute(symbolVec.emplace_back()))
1133 return failure();
1134
1135 if (parser.parseOperand(operands.emplace_back()) ||
1136 parser.parseArrow() ||
1137 parser.parseArgument(regionPrivateArgs.emplace_back()))
1138 return failure();
1139
1140 if (mapIndices) {
1141 if (parser.parseOptionalLSquare().succeeded()) {
1142 if (parser.parseKeyword("map_idx") || parser.parseEqual() ||
1143 parser.parseInteger(mapIndicesVec.emplace_back()) ||
1144 parser.parseRSquare())
1145 return failure();
1146 } else {
1147 mapIndicesVec.push_back(-1);
1148 }
1149 }
1150
1151 return success();
1152 }))
1153 return failure();
1154
1155 if (parser.parseColon())
1156 return failure();
1157
1158 if (parser.parseCommaSeparatedList([&]() {
1159 if (parser.parseType(types.emplace_back()))
1160 return failure();
1161
1162 return success();
1163 }))
1164 return failure();
1165
1166 if (operands.size() != types.size())
1167 return failure();
1168
1169 if (parser.parseRParen())
1170 return failure();
1171
1172 if (needsBarrier) {
1174 .succeeded())
1175 *needsBarrier = mlir::UnitAttr::get(parser.getContext());
1176 }
1177
1178 auto *argsBegin = regionPrivateArgs.begin();
1179 MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
1180 argsBegin + regionArgOffset + types.size());
1181 for (auto [prv, type] : llvm::zip_equal(argsSubrange, types)) {
1182 prv.type = type;
1183 }
1184
1185 if (symbols) {
1186 SmallVector<Attribute> symbolAttrs(symbolVec.begin(), symbolVec.end());
1187 *symbols = ArrayAttr::get(parser.getContext(), symbolAttrs);
1188 }
1189
1190 if (!mapIndicesVec.empty())
1191 *mapIndices =
1192 mlir::DenseI64ArrayAttr::get(parser.getContext(), mapIndicesVec);
1193
1194 if (byref)
1195 *byref = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
1196
1197 return success();
1198}
1199
1200static ParseResult parseBlockArgClause(
1201 OpAsmParser &parser,
1203 StringRef keyword, std::optional<MapParseArgs> mapArgs) {
1204 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1205 if (!mapArgs)
1206 return failure();
1207
1208 if (failed(parseClauseWithRegionArgs(parser, mapArgs->vars, mapArgs->types,
1209 entryBlockArgs)))
1210 return failure();
1211 }
1212 return success();
1213}
1214
1215static ParseResult parseBlockArgClause(
1216 OpAsmParser &parser,
1218 StringRef keyword, std::optional<PrivateParseArgs> privateArgs) {
1219 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1220 if (!privateArgs)
1221 return failure();
1222
1223 if (failed(parseClauseWithRegionArgs(
1224 parser, privateArgs->vars, privateArgs->types, entryBlockArgs,
1225 &privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1226 /*modifier=*/nullptr, &privateArgs->needsBarrier)))
1227 return failure();
1228 }
1229 return success();
1230}
1231
1232static ParseResult parseBlockArgClause(
1233 OpAsmParser &parser,
1235 StringRef keyword, std::optional<ReductionParseArgs> reductionArgs) {
1236 if (succeeded(parser.parseOptionalKeyword(keyword))) {
1237 if (!reductionArgs)
1238 return failure();
1239 if (failed(parseClauseWithRegionArgs(
1240 parser, reductionArgs->vars, reductionArgs->types, entryBlockArgs,
1241 &reductionArgs->syms, /*mapIndices=*/nullptr, &reductionArgs->byref,
1242 reductionArgs->modifier)))
1243 return failure();
1244 }
1245 return success();
1246}
1247
1248static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region,
1249 AllRegionParseArgs args) {
1251
1252 if (failed(parseBlockArgClause(parser, entryBlockArgs, "has_device_addr",
1253 args.hasDeviceAddrArgs)))
1254 return parser.emitError(parser.getCurrentLocation())
1255 << "invalid `has_device_addr` format";
1256
1257 if (failed(parseBlockArgClause(parser, entryBlockArgs, "host_eval",
1258 args.hostEvalArgs)))
1259 return parser.emitError(parser.getCurrentLocation())
1260 << "invalid `host_eval` format";
1261
1262 if (failed(parseBlockArgClause(parser, entryBlockArgs, "in_reduction",
1263 args.inReductionArgs)))
1264 return parser.emitError(parser.getCurrentLocation())
1265 << "invalid `in_reduction` format";
1266
1267 if (failed(parseBlockArgClause(parser, entryBlockArgs, "map_entries",
1268 args.mapArgs)))
1269 return parser.emitError(parser.getCurrentLocation())
1270 << "invalid `map_entries` format";
1271
1272 if (failed(parseBlockArgClause(parser, entryBlockArgs, "private",
1273 args.privateArgs)))
1274 return parser.emitError(parser.getCurrentLocation())
1275 << "invalid `private` format";
1276
1277 if (failed(parseBlockArgClause(parser, entryBlockArgs, "reduction",
1278 args.reductionArgs)))
1279 return parser.emitError(parser.getCurrentLocation())
1280 << "invalid `reduction` format";
1281
1282 if (failed(parseBlockArgClause(parser, entryBlockArgs, "task_reduction",
1283 args.taskReductionArgs)))
1284 return parser.emitError(parser.getCurrentLocation())
1285 << "invalid `task_reduction` format";
1286
1287 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_addr",
1288 args.useDeviceAddrArgs)))
1289 return parser.emitError(parser.getCurrentLocation())
1290 << "invalid `use_device_addr` format";
1291
1292 if (failed(parseBlockArgClause(parser, entryBlockArgs, "use_device_ptr",
1293 args.useDevicePtrArgs)))
1294 return parser.emitError(parser.getCurrentLocation())
1295 << "invalid `use_device_addr` format";
1296
1297 return parser.parseRegion(region, entryBlockArgs);
1298}
1299
1300// These parseXyz functions correspond to the custom<Xyz> definitions
1301// in the .td file(s).
1302static ParseResult parseTargetOpRegion(
1303 OpAsmParser &parser, Region &region,
1305 SmallVectorImpl<Type> &hasDeviceAddrTypes,
1307 SmallVectorImpl<Type> &hostEvalTypes,
1309 SmallVectorImpl<Type> &inReductionTypes,
1310 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1312 SmallVectorImpl<Type> &mapTypes,
1314 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1315 UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps) {
1316 AllRegionParseArgs args;
1317 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1318 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1319 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1320 inReductionByref, inReductionSyms);
1321 args.mapArgs.emplace(mapVars, mapTypes);
1322 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1323 privateNeedsBarrier, &privateMaps);
1324 return parseBlockArgRegion(parser, region, args);
1325}
1326
1328 OpAsmParser &parser, Region &region,
1330 SmallVectorImpl<Type> &inReductionTypes,
1331 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1333 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1334 UnitAttr &privateNeedsBarrier) {
1335 AllRegionParseArgs args;
1336 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1337 inReductionByref, inReductionSyms);
1338 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1339 privateNeedsBarrier);
1340 return parseBlockArgRegion(parser, region, args);
1341}
1342
1344 OpAsmParser &parser, Region &region,
1346 SmallVectorImpl<Type> &inReductionTypes,
1347 DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms,
1349 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1350 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1352 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1353 ArrayAttr &reductionSyms) {
1354 AllRegionParseArgs args;
1355 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1356 inReductionByref, inReductionSyms);
1357 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1358 privateNeedsBarrier);
1359 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1360 reductionSyms, &reductionMod);
1361 return parseBlockArgRegion(parser, region, args);
1362}
1363
1364static ParseResult parsePrivateRegion(
1365 OpAsmParser &parser, Region &region,
1367 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1368 UnitAttr &privateNeedsBarrier) {
1369 AllRegionParseArgs args;
1370 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1371 privateNeedsBarrier);
1372 return parseBlockArgRegion(parser, region, args);
1373}
1374
1376 OpAsmParser &parser, Region &region,
1378 llvm::SmallVectorImpl<Type> &privateTypes, ArrayAttr &privateSyms,
1379 UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod,
1381 SmallVectorImpl<Type> &reductionTypes, DenseBoolArrayAttr &reductionByref,
1382 ArrayAttr &reductionSyms) {
1383 AllRegionParseArgs args;
1384 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1385 privateNeedsBarrier);
1386 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1387 reductionSyms, &reductionMod);
1388 return parseBlockArgRegion(parser, region, args);
1389}
1390
1391static ParseResult parseTaskReductionRegion(
1392 OpAsmParser &parser, Region &region,
1394 SmallVectorImpl<Type> &taskReductionTypes,
1395 DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms) {
1396 AllRegionParseArgs args;
1397 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1398 taskReductionByref, taskReductionSyms);
1399 return parseBlockArgRegion(parser, region, args);
1400}
1401
1403 OpAsmParser &parser, Region &region,
1405 SmallVectorImpl<Type> &useDeviceAddrTypes,
1407 SmallVectorImpl<Type> &useDevicePtrTypes) {
1408 AllRegionParseArgs args;
1409 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1410 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1411 return parseBlockArgRegion(parser, region, args);
1412}
1413
1414//===----------------------------------------------------------------------===//
1415// Printers for operations including clauses that define entry block arguments.
1416//===----------------------------------------------------------------------===//
1417
1418namespace {
1419struct MapPrintArgs {
1420 ValueRange vars;
1421 TypeRange types;
1422 MapPrintArgs(ValueRange vars, TypeRange types) : vars(vars), types(types) {}
1423};
1424struct PrivatePrintArgs {
1425 ValueRange vars;
1426 TypeRange types;
1427 ArrayAttr syms;
1428 UnitAttr needsBarrier;
1429 DenseI64ArrayAttr mapIndices;
1430 PrivatePrintArgs(ValueRange vars, TypeRange types, ArrayAttr syms,
1431 UnitAttr needsBarrier, DenseI64ArrayAttr mapIndices)
1432 : vars(vars), types(types), syms(syms), needsBarrier(needsBarrier),
1433 mapIndices(mapIndices) {}
1434};
1435struct ReductionPrintArgs {
1436 ValueRange vars;
1437 TypeRange types;
1438 DenseBoolArrayAttr byref;
1439 ArrayAttr syms;
1440 ReductionModifierAttr modifier;
1441 ReductionPrintArgs(ValueRange vars, TypeRange types, DenseBoolArrayAttr byref,
1442 ArrayAttr syms, ReductionModifierAttr mod = nullptr)
1443 : vars(vars), types(types), byref(byref), syms(syms), modifier(mod) {}
1444};
1445struct AllRegionPrintArgs {
1446 std::optional<MapPrintArgs> hasDeviceAddrArgs;
1447 std::optional<MapPrintArgs> hostEvalArgs;
1448 std::optional<ReductionPrintArgs> inReductionArgs;
1449 std::optional<MapPrintArgs> mapArgs;
1450 std::optional<PrivatePrintArgs> privateArgs;
1451 std::optional<ReductionPrintArgs> reductionArgs;
1452 std::optional<ReductionPrintArgs> taskReductionArgs;
1453 std::optional<MapPrintArgs> useDeviceAddrArgs;
1454 std::optional<MapPrintArgs> useDevicePtrArgs;
1455};
1456} // namespace
1457
1459 OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1460 ValueRange argsSubrange, ValueRange operands, TypeRange types,
1461 ArrayAttr symbols = nullptr, DenseI64ArrayAttr mapIndices = nullptr,
1462 DenseBoolArrayAttr byref = nullptr,
1463 ReductionModifierAttr modifier = nullptr, UnitAttr needsBarrier = nullptr) {
1464 if (argsSubrange.empty())
1465 return;
1466
1467 p << clauseName << "(";
1468
1469 if (modifier)
1470 p << "mod: " << stringifyReductionModifier(modifier.getValue()) << ", ";
1471
1472 if (!symbols) {
1473 llvm::SmallVector<Attribute> values(operands.size(), nullptr);
1474 symbols = ArrayAttr::get(ctx, values);
1475 }
1476
1477 if (!mapIndices) {
1478 llvm::SmallVector<int64_t> values(operands.size(), -1);
1479 mapIndices = DenseI64ArrayAttr::get(ctx, values);
1480 }
1481
1482 if (!byref) {
1483 mlir::SmallVector<bool> values(operands.size(), false);
1484 byref = DenseBoolArrayAttr::get(ctx, values);
1485 }
1486
1487 llvm::interleaveComma(llvm::zip_equal(operands, argsSubrange, symbols,
1488 mapIndices.asArrayRef(),
1489 byref.asArrayRef()),
1490 p, [&p](auto t) {
1491 auto [op, arg, sym, map, isByRef] = t;
1492 if (isByRef)
1493 p << "byref ";
1494 if (sym)
1495 p << sym << " ";
1496
1497 p << op << " -> " << arg;
1498
1499 if (map != -1)
1500 p << " [map_idx=" << map << "]";
1501 });
1502 p << " : ";
1503 llvm::interleaveComma(types, p);
1504 p << ") ";
1505
1506 if (needsBarrier)
1507 p << getPrivateNeedsBarrierSpelling() << " ";
1508}
1509
1511 StringRef clauseName, ValueRange argsSubrange,
1512 std::optional<MapPrintArgs> mapArgs) {
1513 if (mapArgs)
1514 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange, mapArgs->vars,
1515 mapArgs->types);
1516}
1517
1519 StringRef clauseName, ValueRange argsSubrange,
1520 std::optional<PrivatePrintArgs> privateArgs) {
1521 if (privateArgs)
1523 p, ctx, clauseName, argsSubrange, privateArgs->vars, privateArgs->types,
1524 privateArgs->syms, privateArgs->mapIndices, /*byref=*/nullptr,
1525 /*modifier=*/nullptr, privateArgs->needsBarrier);
1526}
1527
1528static void
1529printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName,
1530 ValueRange argsSubrange,
1531 std::optional<ReductionPrintArgs> reductionArgs) {
1532 if (reductionArgs)
1533 printClauseWithRegionArgs(p, ctx, clauseName, argsSubrange,
1534 reductionArgs->vars, reductionArgs->types,
1535 reductionArgs->syms, /*mapIndices=*/nullptr,
1536 reductionArgs->byref, reductionArgs->modifier);
1537}
1538
1540 const AllRegionPrintArgs &args) {
1541 auto iface = llvm::cast<mlir::omp::BlockArgOpenMPOpInterface>(op);
1542 MLIRContext *ctx = op->getContext();
1543
1544 printBlockArgClause(p, ctx, "has_device_addr",
1545 iface.getHasDeviceAddrBlockArgs(),
1546 args.hasDeviceAddrArgs);
1547 printBlockArgClause(p, ctx, "host_eval", iface.getHostEvalBlockArgs(),
1548 args.hostEvalArgs);
1549 printBlockArgClause(p, ctx, "in_reduction", iface.getInReductionBlockArgs(),
1550 args.inReductionArgs);
1551 printBlockArgClause(p, ctx, "map_entries", iface.getMapBlockArgs(),
1552 args.mapArgs);
1553 printBlockArgClause(p, ctx, "private", iface.getPrivateBlockArgs(),
1554 args.privateArgs);
1555 printBlockArgClause(p, ctx, "reduction", iface.getReductionBlockArgs(),
1556 args.reductionArgs);
1557 printBlockArgClause(p, ctx, "task_reduction",
1558 iface.getTaskReductionBlockArgs(),
1559 args.taskReductionArgs);
1560 printBlockArgClause(p, ctx, "use_device_addr",
1561 iface.getUseDeviceAddrBlockArgs(),
1562 args.useDeviceAddrArgs);
1563 printBlockArgClause(p, ctx, "use_device_ptr",
1564 iface.getUseDevicePtrBlockArgs(), args.useDevicePtrArgs);
1565
1566 p.printRegion(region, /*printEntryBlockArgs=*/false);
1567}
1568
// These printXyz functions correspond to the custom<Xyz> definitions
1570// in the .td file(s).
1572 OpAsmPrinter &p, Operation *op, Region &region,
1573 ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes,
1574 ValueRange hostEvalVars, TypeRange hostEvalTypes,
1575 ValueRange inReductionVars, TypeRange inReductionTypes,
1576 DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms,
1577 ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars,
1578 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1579 DenseI64ArrayAttr privateMaps) {
1580 AllRegionPrintArgs args;
1581 args.hasDeviceAddrArgs.emplace(hasDeviceAddrVars, hasDeviceAddrTypes);
1582 args.hostEvalArgs.emplace(hostEvalVars, hostEvalTypes);
1583 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1584 inReductionByref, inReductionSyms);
1585 args.mapArgs.emplace(mapVars, mapTypes);
1586 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1587 privateNeedsBarrier, privateMaps);
1588 printBlockArgRegion(p, op, region, args);
1589}
1590
1592 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1593 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1594 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1595 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier) {
1596 AllRegionPrintArgs args;
1597 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1598 inReductionByref, inReductionSyms);
1599 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1600 privateNeedsBarrier,
1601 /*mapIndices=*/nullptr);
1602 printBlockArgRegion(p, op, region, args);
1603}
1604
1606 OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars,
1607 TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref,
1608 ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes,
1609 ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1610 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1611 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1612 ArrayAttr reductionSyms) {
1613 AllRegionPrintArgs args;
1614 args.inReductionArgs.emplace(inReductionVars, inReductionTypes,
1615 inReductionByref, inReductionSyms);
1616 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1617 privateNeedsBarrier,
1618 /*mapIndices=*/nullptr);
1619 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1620 reductionSyms, reductionMod);
1621 printBlockArgRegion(p, op, region, args);
1622}
1623
1625 ValueRange privateVars, TypeRange privateTypes,
1626 ArrayAttr privateSyms,
1627 UnitAttr privateNeedsBarrier) {
1628 AllRegionPrintArgs args;
1629 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1630 privateNeedsBarrier,
1631 /*mapIndices=*/nullptr);
1632 printBlockArgRegion(p, op, region, args);
1633}
1634
1636 OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars,
1637 TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier,
1638 ReductionModifierAttr reductionMod, ValueRange reductionVars,
1639 TypeRange reductionTypes, DenseBoolArrayAttr reductionByref,
1640 ArrayAttr reductionSyms) {
1641 AllRegionPrintArgs args;
1642 args.privateArgs.emplace(privateVars, privateTypes, privateSyms,
1643 privateNeedsBarrier,
1644 /*mapIndices=*/nullptr);
1645 args.reductionArgs.emplace(reductionVars, reductionTypes, reductionByref,
1646 reductionSyms, reductionMod);
1647 printBlockArgRegion(p, op, region, args);
1648}
1649
1651 Region &region,
1652 ValueRange taskReductionVars,
1653 TypeRange taskReductionTypes,
1654 DenseBoolArrayAttr taskReductionByref,
1655 ArrayAttr taskReductionSyms) {
1656 AllRegionPrintArgs args;
1657 args.taskReductionArgs.emplace(taskReductionVars, taskReductionTypes,
1658 taskReductionByref, taskReductionSyms);
1659 printBlockArgRegion(p, op, region, args);
1660}
1661
1663 Region &region,
1664 ValueRange useDeviceAddrVars,
1665 TypeRange useDeviceAddrTypes,
1666 ValueRange useDevicePtrVars,
1667 TypeRange useDevicePtrTypes) {
1668 AllRegionPrintArgs args;
1669 args.useDeviceAddrArgs.emplace(useDeviceAddrVars, useDeviceAddrTypes);
1670 args.useDevicePtrArgs.emplace(useDevicePtrVars, useDevicePtrTypes);
1671 printBlockArgRegion(p, op, region, args);
1672}
1673
1674template <typename ParsePrefixFn>
1675static ParseResult parseSplitIteratedList(
1676 OpAsmParser &parser,
1678 SmallVectorImpl<Type> &iteratedTypes,
1680 SmallVectorImpl<Type> &plainTypes, ParsePrefixFn &&parsePrefix) {
1681
1682 return parser.parseCommaSeparatedList([&]() -> ParseResult {
1683 if (failed(parsePrefix()))
1684 return failure();
1685
1687 Type ty;
1688 if (parser.parseOperand(v) || parser.parseColonType(ty))
1689 return failure();
1690
1691 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1692 iteratedVars.push_back(v);
1693 iteratedTypes.push_back(ty);
1694 } else {
1695 plainVars.push_back(v);
1696 plainTypes.push_back(ty);
1697 }
1698 return success();
1699 });
1700}
1701
1702template <typename PrintPrefixFn>
1704 TypeRange iteratedTypes,
1705 ValueRange plainVars, TypeRange plainTypes,
1706 PrintPrefixFn &&printPrefixForPlain,
1707 PrintPrefixFn &&printPrefixForIterated) {
1708
1709 bool first = true;
1710 auto emit = [&](Value v, Type t, auto &&printPrefix) {
1711 if (!first)
1712 p << ", ";
1713 printPrefix(v, t);
1714 p << v << " : " << t;
1715 first = false;
1716 };
1717
1718 for (unsigned i = 0; i < iteratedVars.size(); ++i)
1719 emit(iteratedVars[i], iteratedTypes[i], printPrefixForIterated);
1720 for (unsigned i = 0; i < plainVars.size(); ++i)
1721 emit(plainVars[i], plainTypes[i], printPrefixForPlain);
1722}
1723
1724/// Verifies Reduction Clause
1725static LogicalResult
1726verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductionSyms,
1727 OperandRange reductionVars,
1728 std::optional<ArrayRef<bool>> reductionByref) {
1729 if (!reductionVars.empty()) {
1730 if (!reductionSyms || reductionSyms->size() != reductionVars.size())
1731 return op->emitOpError()
1732 << "expected as many reduction symbol references "
1733 "as reduction variables";
1734 if (reductionByref && reductionByref->size() != reductionVars.size())
1735 return op->emitError() << "expected as many reduction variable by "
1736 "reference attributes as reduction variables";
1737 } else {
1738 if (reductionSyms)
1739 return op->emitOpError() << "unexpected reduction symbol references";
1740 return success();
1741 }
1742
1743 // TODO: The followings should be done in
1744 // SymbolUserOpInterface::verifySymbolUses.
1745 DenseSet<Value> accumulators;
1746 for (auto args : llvm::zip(reductionVars, *reductionSyms)) {
1747 Value accum = std::get<0>(args);
1748
1749 if (!accumulators.insert(accum).second)
1750 return op->emitOpError() << "accumulator variable used more than once";
1751
1752 Type varType = accum.getType();
1753 auto symbolRef = llvm::cast<SymbolRefAttr>(std::get<1>(args));
1754 auto decl =
1756 if (!decl)
1757 return op->emitOpError() << "expected symbol reference " << symbolRef
1758 << " to point to a reduction declaration";
1759
1760 if (decl.getAccumulatorType() && decl.getAccumulatorType() != varType)
1761 return op->emitOpError()
1762 << "expected accumulator (" << varType
1763 << ") to be the same type as reduction declaration ("
1764 << decl.getAccumulatorType() << ")";
1765 }
1766
1767 return success();
1768}
1769
1770//===----------------------------------------------------------------------===//
1771// Parser, printer and verifier for Copyprivate
1772//===----------------------------------------------------------------------===//
1773
1774/// copyprivate-entry-list ::= copyprivate-entry
1775/// | copyprivate-entry-list `,` copyprivate-entry
1776/// copyprivate-entry ::= ssa-id `->` symbol-ref `:` type
1777static ParseResult parseCopyprivate(
1778 OpAsmParser &parser,
1780 SmallVectorImpl<Type> &copyprivateTypes, ArrayAttr &copyprivateSyms) {
1782 if (failed(parser.parseCommaSeparatedList([&]() {
1783 if (parser.parseOperand(copyprivateVars.emplace_back()) ||
1784 parser.parseArrow() ||
1785 parser.parseAttribute(symsVec.emplace_back()) ||
1786 parser.parseColonType(copyprivateTypes.emplace_back()))
1787 return failure();
1788 return success();
1789 })))
1790 return failure();
1791 SmallVector<Attribute> syms(symsVec.begin(), symsVec.end());
1792 copyprivateSyms = ArrayAttr::get(parser.getContext(), syms);
1793 return success();
1794}
1795
1796/// Print Copyprivate clause
1798 OperandRange copyprivateVars,
1799 TypeRange copyprivateTypes,
1800 std::optional<ArrayAttr> copyprivateSyms) {
1801 if (!copyprivateSyms.has_value())
1802 return;
1803 llvm::interleaveComma(
1804 llvm::zip(copyprivateVars, *copyprivateSyms, copyprivateTypes), p,
1805 [&](const auto &args) {
1806 p << std::get<0>(args) << " -> " << std::get<1>(args) << " : "
1807 << std::get<2>(args);
1808 });
1809}
1810
1811/// Verifies CopyPrivate Clause
1812static LogicalResult
1814 std::optional<ArrayAttr> copyprivateSyms) {
1815 size_t copyprivateSymsSize =
1816 copyprivateSyms.has_value() ? copyprivateSyms->size() : 0;
1817 if (copyprivateSymsSize != copyprivateVars.size())
1818 return op->emitOpError() << "inconsistent number of copyprivate vars (= "
1819 << copyprivateVars.size()
1820 << ") and functions (= " << copyprivateSymsSize
1821 << "), both must be equal";
1822 if (!copyprivateSyms.has_value())
1823 return success();
1824
1825 for (auto copyprivateVarAndSym :
1826 llvm::zip(copyprivateVars, *copyprivateSyms)) {
1827 auto symbolRef =
1828 llvm::cast<SymbolRefAttr>(std::get<1>(copyprivateVarAndSym));
1829 std::optional<std::variant<mlir::func::FuncOp, mlir::LLVM::LLVMFuncOp>>
1830 funcOp;
1831 if (mlir::func::FuncOp mlirFuncOp =
1833 symbolRef))
1834 funcOp = mlirFuncOp;
1835 else if (mlir::LLVM::LLVMFuncOp llvmFuncOp =
1837 op, symbolRef))
1838 funcOp = llvmFuncOp;
1839
1840 auto getNumArguments = [&] {
1841 return std::visit([](auto &f) { return f.getNumArguments(); }, *funcOp);
1842 };
1843
1844 auto getArgumentType = [&](unsigned i) {
1845 return std::visit([i](auto &f) { return f.getArgumentTypes()[i]; },
1846 *funcOp);
1847 };
1848
1849 if (!funcOp)
1850 return op->emitOpError() << "expected symbol reference " << symbolRef
1851 << " to point to a copy function";
1852
1853 if (getNumArguments() != 2)
1854 return op->emitOpError()
1855 << "expected copy function " << symbolRef << " to have 2 operands";
1856
1857 Type argTy = getArgumentType(0);
1858 if (argTy != getArgumentType(1))
1859 return op->emitOpError() << "expected copy function " << symbolRef
1860 << " arguments to have the same type";
1861
1862 Type varType = std::get<0>(copyprivateVarAndSym).getType();
1863 if (argTy != varType)
1864 return op->emitOpError()
1865 << "expected copy function arguments' type (" << argTy
1866 << ") to be the same as copyprivate variable's type (" << varType
1867 << ")";
1868 }
1869
1870 return success();
1871}
1872
1873//===----------------------------------------------------------------------===//
1874// Parser, printer and verifier for DependVarList
1875//===----------------------------------------------------------------------===//
1876
1877/// depend-entry-list ::= depend-entry
1878/// | depend-entry-list `,` depend-entry
1879/// depend-entry ::= depend-kind `->` ssa-id `:` type
1880/// | depend-kind `->` ssa-id `:` iterated-type
1881static ParseResult parseDependVarList(
1882 OpAsmParser &parser,
1884 SmallVectorImpl<Type> &dependTypes, ArrayAttr &dependKinds,
1886 SmallVectorImpl<Type> &iteratedTypes, ArrayAttr &iteratedKinds) {
1889 if (failed(parser.parseCommaSeparatedList([&]() {
1890 StringRef keyword;
1891 OpAsmParser::UnresolvedOperand operand;
1892 Type ty;
1893 if (parser.parseKeyword(&keyword) || parser.parseArrow() ||
1894 parser.parseOperand(operand) || parser.parseColonType(ty))
1895 return failure();
1896 std::optional<ClauseTaskDepend> keywordDepend =
1897 symbolizeClauseTaskDepend(keyword);
1898 if (!keywordDepend)
1899 return failure();
1900 auto kindAttr =
1901 ClauseTaskDependAttr::get(parser.getContext(), *keywordDepend);
1902 if (llvm::isa<mlir::omp::IteratedType>(ty)) {
1903 iteratedVars.push_back(operand);
1904 iteratedTypes.push_back(ty);
1905 iterKindsVec.push_back(kindAttr);
1906 } else {
1907 dependVars.push_back(operand);
1908 dependTypes.push_back(ty);
1909 kindsVec.push_back(kindAttr);
1910 }
1911 return success();
1912 })))
1913 return failure();
1914 SmallVector<Attribute> kinds(kindsVec.begin(), kindsVec.end());
1915 dependKinds = ArrayAttr::get(parser.getContext(), kinds);
1916 SmallVector<Attribute> iterKinds(iterKindsVec.begin(), iterKindsVec.end());
1917 iteratedKinds = ArrayAttr::get(parser.getContext(), iterKinds);
1918 return success();
1919}
1920
1921/// Print Depend clause
1923 OperandRange dependVars, TypeRange dependTypes,
1924 std::optional<ArrayAttr> dependKinds,
1925 OperandRange iteratedVars,
1926 TypeRange iteratedTypes,
1927 std::optional<ArrayAttr> iteratedKinds) {
1928 bool first = true;
1929 auto printEntries = [&](OperandRange vars, TypeRange types,
1930 std::optional<ArrayAttr> kinds) {
1931 for (unsigned i = 0, e = vars.size(); i < e; ++i) {
1932 if (!first)
1933 p << ", ";
1934 p << stringifyClauseTaskDepend(
1935 llvm::cast<mlir::omp::ClauseTaskDependAttr>((*kinds)[i])
1936 .getValue())
1937 << " -> " << vars[i] << " : " << types[i];
1938 first = false;
1939 }
1940 };
1941 printEntries(dependVars, dependTypes, dependKinds);
1942 printEntries(iteratedVars, iteratedTypes, iteratedKinds);
1943}
1944
1945/// Verifies Depend clause
1946static LogicalResult verifyDependVarList(Operation *op,
1947 std::optional<ArrayAttr> dependKinds,
1948 OperandRange dependVars,
1949 std::optional<ArrayAttr> iteratedKinds,
1950 OperandRange iteratedVars) {
1951 if (!dependVars.empty()) {
1952 if (!dependKinds || dependKinds->size() != dependVars.size())
1953 return op->emitOpError() << "expected as many depend values"
1954 " as depend variables";
1955 } else {
1956 if (dependKinds && !dependKinds->empty())
1957 return op->emitOpError() << "unexpected depend values";
1958 }
1959
1960 if (!iteratedVars.empty()) {
1961 if (!iteratedKinds || iteratedKinds->size() != iteratedVars.size())
1962 return op->emitOpError() << "expected as many depend iterated values"
1963 " as depend iterated variables";
1964 } else {
1965 if (iteratedKinds && !iteratedKinds->empty())
1966 return op->emitOpError() << "unexpected depend iterated values";
1967 }
1968
1969 return success();
1970}
1971
1972//===----------------------------------------------------------------------===//
1973// Parser, printer and verifier for Synchronization Hint (2.17.12)
1974//===----------------------------------------------------------------------===//
1975
1976/// Parses a Synchronization Hint clause. The value of hint is an integer
1977/// which is a combination of different hints from `omp_sync_hint_t`.
1978///
1979/// hint-clause = `hint` `(` hint-value `)`
1980static ParseResult parseSynchronizationHint(OpAsmParser &parser,
1981 IntegerAttr &hintAttr) {
1982 StringRef hintKeyword;
1983 int64_t hint = 0;
1984 if (succeeded(parser.parseOptionalKeyword("none"))) {
1985 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), 0);
1986 return success();
1987 }
1988 auto parseKeyword = [&]() -> ParseResult {
1989 if (failed(parser.parseKeyword(&hintKeyword)))
1990 return failure();
1991 if (hintKeyword == "uncontended")
1992 hint |= 1;
1993 else if (hintKeyword == "contended")
1994 hint |= 2;
1995 else if (hintKeyword == "nonspeculative")
1996 hint |= 4;
1997 else if (hintKeyword == "speculative")
1998 hint |= 8;
1999 else
2000 return parser.emitError(parser.getCurrentLocation())
2001 << hintKeyword << " is not a valid hint";
2002 return success();
2003 };
2004 if (parser.parseCommaSeparatedList(parseKeyword))
2005 return failure();
2006 hintAttr = IntegerAttr::get(parser.getBuilder().getI64Type(), hint);
2007 return success();
2008}
2009
2010/// Prints a Synchronization Hint clause
2012 IntegerAttr hintAttr) {
2013 int64_t hint = hintAttr.getInt();
2014
2015 if (hint == 0) {
2016 p << "none";
2017 return;
2018 }
2019
2020 // Helper function to get n-th bit from the right end of `value`
2021 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2022
2023 bool uncontended = bitn(hint, 0);
2024 bool contended = bitn(hint, 1);
2025 bool nonspeculative = bitn(hint, 2);
2026 bool speculative = bitn(hint, 3);
2027
2029 if (uncontended)
2030 hints.push_back("uncontended");
2031 if (contended)
2032 hints.push_back("contended");
2033 if (nonspeculative)
2034 hints.push_back("nonspeculative");
2035 if (speculative)
2036 hints.push_back("speculative");
2037
2038 llvm::interleaveComma(hints, p);
2039}
2040
2041/// Verifies a synchronization hint clause
2042static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint) {
2043
2044 // Helper function to get n-th bit from the right end of `value`
2045 auto bitn = [](int value, int n) -> bool { return value & (1 << n); };
2046
2047 bool uncontended = bitn(hint, 0);
2048 bool contended = bitn(hint, 1);
2049 bool nonspeculative = bitn(hint, 2);
2050 bool speculative = bitn(hint, 3);
2051
2052 if (uncontended && contended)
2053 return op->emitOpError() << "the hints omp_sync_hint_uncontended and "
2054 "omp_sync_hint_contended cannot be combined";
2055 if (nonspeculative && speculative)
2056 return op->emitOpError() << "the hints omp_sync_hint_nonspeculative and "
2057 "omp_sync_hint_speculative cannot be combined.";
2058 return success();
2059}
2060
2061//===----------------------------------------------------------------------===//
2062// Parser, printer and verifier for Target
2063//===----------------------------------------------------------------------===//
2064
2065// Helper function to get bitwise AND of `value` and 'flag' then return it as a
2066// boolean
2067static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag) {
2068 return (value & flag) == flag;
2069}
2070
2071/// Parses a map_entries map type from a string format back into its numeric
2072/// value.
2073///
2074/// map-clause = `map_clauses ( ( `(` `always, `? `implicit, `? `ompx_hold, `?
2075/// `close, `? `present, `? ( `to` | `from` | `delete` `)` )+ `)` )
2076static ParseResult parseMapClause(OpAsmParser &parser,
2077 ClauseMapFlagsAttr &mapType) {
2078 ClauseMapFlags mapTypeBits = ClauseMapFlags::none;
2079 // This simply verifies the correct keyword is read in, the
2080 // keyword itself is stored inside of the operation
2081 auto parseTypeAndMod = [&]() -> ParseResult {
2082 StringRef mapTypeMod;
2083 if (parser.parseKeyword(&mapTypeMod))
2084 return failure();
2085
2086 if (mapTypeMod == "always")
2087 mapTypeBits |= ClauseMapFlags::always;
2088
2089 if (mapTypeMod == "implicit")
2090 mapTypeBits |= ClauseMapFlags::implicit;
2091
2092 if (mapTypeMod == "ompx_hold")
2093 mapTypeBits |= ClauseMapFlags::ompx_hold;
2094
2095 if (mapTypeMod == "close")
2096 mapTypeBits |= ClauseMapFlags::close;
2097
2098 if (mapTypeMod == "present")
2099 mapTypeBits |= ClauseMapFlags::present;
2100
2101 if (mapTypeMod == "to")
2102 mapTypeBits |= ClauseMapFlags::to;
2103
2104 if (mapTypeMod == "from")
2105 mapTypeBits |= ClauseMapFlags::from;
2106
2107 if (mapTypeMod == "tofrom")
2108 mapTypeBits |= ClauseMapFlags::to | ClauseMapFlags::from;
2109
2110 if (mapTypeMod == "delete")
2111 mapTypeBits |= ClauseMapFlags::del;
2112
2113 if (mapTypeMod == "storage")
2114 mapTypeBits |= ClauseMapFlags::storage;
2115
2116 if (mapTypeMod == "return_param")
2117 mapTypeBits |= ClauseMapFlags::return_param;
2118
2119 if (mapTypeMod == "private")
2120 mapTypeBits |= ClauseMapFlags::priv;
2121
2122 if (mapTypeMod == "literal")
2123 mapTypeBits |= ClauseMapFlags::literal;
2124
2125 if (mapTypeMod == "attach")
2126 mapTypeBits |= ClauseMapFlags::attach;
2127
2128 if (mapTypeMod == "attach_always")
2129 mapTypeBits |= ClauseMapFlags::attach_always;
2130
2131 if (mapTypeMod == "attach_never")
2132 mapTypeBits |= ClauseMapFlags::attach_never;
2133
2134 if (mapTypeMod == "attach_auto")
2135 mapTypeBits |= ClauseMapFlags::attach_auto;
2136
2137 if (mapTypeMod == "ref_ptr")
2138 mapTypeBits |= ClauseMapFlags::ref_ptr;
2139
2140 if (mapTypeMod == "ref_ptee")
2141 mapTypeBits |= ClauseMapFlags::ref_ptee;
2142
2143 if (mapTypeMod == "is_device_ptr")
2144 mapTypeBits |= ClauseMapFlags::is_device_ptr;
2145
2146 return success();
2147 };
2148
2149 if (parser.parseCommaSeparatedList(parseTypeAndMod))
2150 return failure();
2151
2152 mapType =
2153 parser.getBuilder().getAttr<mlir::omp::ClauseMapFlagsAttr>(mapTypeBits);
2154
2155 return success();
2156}
2157
2158/// Prints a map_entries map type from its numeric value out into its string
2159/// format.
2160static void printMapClause(OpAsmPrinter &p, Operation *op,
2161 ClauseMapFlagsAttr mapType) {
2163 ClauseMapFlags mapFlags = mapType.getValue();
2164
2165 // handling of always, close, present placed at the beginning of the string
2166 // to aid readability
2167 if (mapTypeToBool(mapFlags, ClauseMapFlags::always))
2168 mapTypeStrs.push_back("always");
2169 if (mapTypeToBool(mapFlags, ClauseMapFlags::implicit))
2170 mapTypeStrs.push_back("implicit");
2171 if (mapTypeToBool(mapFlags, ClauseMapFlags::ompx_hold))
2172 mapTypeStrs.push_back("ompx_hold");
2173 if (mapTypeToBool(mapFlags, ClauseMapFlags::close))
2174 mapTypeStrs.push_back("close");
2175 if (mapTypeToBool(mapFlags, ClauseMapFlags::present))
2176 mapTypeStrs.push_back("present");
2177
2178 // special handling of to/from/tofrom/delete and release/alloc, release +
2179 // alloc are the abscense of one of the other flags, whereas tofrom requires
2180 // both the to and from flag to be set.
2181 bool to = mapTypeToBool(mapFlags, ClauseMapFlags::to);
2182 bool from = mapTypeToBool(mapFlags, ClauseMapFlags::from);
2183
2184 if (to && from)
2185 mapTypeStrs.push_back("tofrom");
2186 else if (from)
2187 mapTypeStrs.push_back("from");
2188 else if (to)
2189 mapTypeStrs.push_back("to");
2190
2191 if (mapTypeToBool(mapFlags, ClauseMapFlags::del))
2192 mapTypeStrs.push_back("delete");
2193 if (mapTypeToBool(mapFlags, ClauseMapFlags::return_param))
2194 mapTypeStrs.push_back("return_param");
2195 if (mapTypeToBool(mapFlags, ClauseMapFlags::storage))
2196 mapTypeStrs.push_back("storage");
2197 if (mapTypeToBool(mapFlags, ClauseMapFlags::priv))
2198 mapTypeStrs.push_back("private");
2199 if (mapTypeToBool(mapFlags, ClauseMapFlags::literal))
2200 mapTypeStrs.push_back("literal");
2201 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach))
2202 mapTypeStrs.push_back("attach");
2203 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_always))
2204 mapTypeStrs.push_back("attach_always");
2205 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_never))
2206 mapTypeStrs.push_back("attach_never");
2207 if (mapTypeToBool(mapFlags, ClauseMapFlags::attach_auto))
2208 mapTypeStrs.push_back("attach_auto");
2209 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptr))
2210 mapTypeStrs.push_back("ref_ptr");
2211 if (mapTypeToBool(mapFlags, ClauseMapFlags::ref_ptee))
2212 mapTypeStrs.push_back("ref_ptee");
2213 if (mapTypeToBool(mapFlags, ClauseMapFlags::is_device_ptr))
2214 mapTypeStrs.push_back("is_device_ptr");
2215 if (mapFlags == ClauseMapFlags::none)
2216 mapTypeStrs.push_back("none");
2217
2218 for (unsigned int i = 0; i < mapTypeStrs.size(); ++i) {
2219 p << mapTypeStrs[i];
2220 if (i + 1 < mapTypeStrs.size()) {
2221 p << ", ";
2222 }
2223 }
2224}
2225
2226static ParseResult parseMembersIndex(OpAsmParser &parser,
2227 ArrayAttr &membersIdx) {
2228 SmallVector<Attribute> values, memberIdxs;
2229
2230 auto parseIndices = [&]() -> ParseResult {
2231 int64_t value;
2232 if (parser.parseInteger(value))
2233 return failure();
2234 values.push_back(IntegerAttr::get(parser.getBuilder().getIntegerType(64),
2235 APInt(64, value, /*isSigned=*/false)));
2236 return success();
2237 };
2238
2239 do {
2240 if (failed(parser.parseLSquare()))
2241 return failure();
2242
2243 if (parser.parseCommaSeparatedList(parseIndices))
2244 return failure();
2245
2246 if (failed(parser.parseRSquare()))
2247 return failure();
2248
2249 memberIdxs.push_back(ArrayAttr::get(parser.getContext(), values));
2250 values.clear();
2251 } while (succeeded(parser.parseOptionalComma()));
2252
2253 if (!memberIdxs.empty())
2254 membersIdx = ArrayAttr::get(parser.getContext(), memberIdxs);
2255
2256 return success();
2257}
2258
2259static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op,
2260 ArrayAttr membersIdx) {
2261 if (!membersIdx)
2262 return;
2263
2264 llvm::interleaveComma(membersIdx, p, [&p](Attribute v) {
2265 p << "[";
2266 auto memberIdx = cast<ArrayAttr>(v);
2267 llvm::interleaveComma(memberIdx.getValue(), p, [&p](Attribute v2) {
2268 p << cast<IntegerAttr>(v2).getInt();
2269 });
2270 p << "]";
2271 });
2272}
2273
2275 VariableCaptureKindAttr mapCaptureType) {
2276 std::string typeCapStr;
2277 llvm::raw_string_ostream typeCap(typeCapStr);
2278 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByRef)
2279 typeCap << "ByRef";
2280 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::ByCopy)
2281 typeCap << "ByCopy";
2282 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::VLAType)
2283 typeCap << "VLAType";
2284 if (mapCaptureType.getValue() == mlir::omp::VariableCaptureKind::This)
2285 typeCap << "This";
2286 p << typeCapStr;
2287}
2288
2289static ParseResult parseCaptureType(OpAsmParser &parser,
2290 VariableCaptureKindAttr &mapCaptureType) {
2291 StringRef mapCaptureKey;
2292 if (parser.parseKeyword(&mapCaptureKey))
2293 return failure();
2294
2295 if (mapCaptureKey == "This")
2296 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2297 parser.getContext(), mlir::omp::VariableCaptureKind::This);
2298 if (mapCaptureKey == "ByRef")
2299 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2300 parser.getContext(), mlir::omp::VariableCaptureKind::ByRef);
2301 if (mapCaptureKey == "ByCopy")
2302 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2303 parser.getContext(), mlir::omp::VariableCaptureKind::ByCopy);
2304 if (mapCaptureKey == "VLAType")
2305 mapCaptureType = mlir::omp::VariableCaptureKindAttr::get(
2306 parser.getContext(), mlir::omp::VariableCaptureKind::VLAType);
2307
2308 return success();
2309}
2310
2311static LogicalResult verifyMapInfoForMapClause(
2312 Operation *op, mlir::omp::MapInfoOp mapInfoOp,
2315 &updateFromVars) {
2316 mlir::omp::ClauseMapFlags mapTypeBits = mapInfoOp.getMapType();
2317
2318 bool to = mapTypeToBool(mapTypeBits, ClauseMapFlags::to);
2319 bool from = mapTypeToBool(mapTypeBits, ClauseMapFlags::from);
2320 bool del = mapTypeToBool(mapTypeBits, ClauseMapFlags::del);
2321
2322 bool always = mapTypeToBool(mapTypeBits, ClauseMapFlags::always);
2323 bool close = mapTypeToBool(mapTypeBits, ClauseMapFlags::close);
2324 bool implicit = mapTypeToBool(mapTypeBits, ClauseMapFlags::implicit);
2325 bool attach = mapTypeToBool(mapTypeBits, ClauseMapFlags::attach);
2326
2327 if ((isa<TargetDataOp>(op) || isa<TargetOp>(op)) && del)
2328 return emitError(op->getLoc(),
2329 "to, from, tofrom and alloc map types are permitted");
2330
2331 if (isa<TargetEnterDataOp>(op) && (from || del))
2332 return emitError(op->getLoc(), "to and alloc map types are permitted");
2333
2334 if (isa<TargetExitDataOp>(op) && to)
2335 return emitError(op->getLoc(),
2336 "from, release and delete map types are permitted");
2337
2338 if (isa<TargetUpdateOp>(op)) {
2339 if (del) {
2340 return emitError(op->getLoc(),
2341 "at least one of to or from map types must be "
2342 "specified, other map types are not permitted");
2343 }
2344
2345 if (!to && !from && !attach) {
2346 return emitError(op->getLoc(),
2347 "at least one of to or from or attach map types must be "
2348 "specified, other map types are not permitted");
2349 }
2350
2351 auto updateVar = mapInfoOp.getVarPtr();
2352
2353 if ((to && from) || (to && updateFromVars.contains(updateVar)) ||
2354 (from && updateToVars.contains(updateVar))) {
2355 return emitError(
2356 op->getLoc(),
2357 "either to or from map types can be specified, not both");
2358 }
2359
2360 if (always || close || implicit) {
2361 return emitError(
2362 op->getLoc(),
2363 "present, mapper and iterator map type modifiers are permitted");
2364 }
2365
2366 // It's possible we have an attach map, in which case if there is no to
2367 // or from tied to it, we skip insertion.
2368 if (to || from) {
2369 to ? updateToVars.insert(updateVar) : updateFromVars.insert(updateVar);
2370 }
2371 }
2372
2373 if ((mapInfoOp.getVarPtrPtr() && !mapInfoOp.getVarPtrPtrType()) ||
2374 (!mapInfoOp.getVarPtrPtr() && mapInfoOp.getVarPtrPtrType())) {
2375 return emitError(op->getLoc(),
2376 "if varPtrPtr or varPtrPtrType is specified, then both "
2377 "must be present");
2378 }
2379
2380 return success();
2381}
2382
2383static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars,
2384 OperandRange mapIterated) {
2387
2388 for (auto mapOp : mapVars) {
2389 if (!mapOp.getDefiningOp())
2390 return emitError(op->getLoc(), "missing map operation");
2391
2392 if (auto mapInfoOp = mapOp.getDefiningOp<mlir::omp::MapInfoOp>()) {
2393 if (failed(verifyMapInfoForMapClause(op, mapInfoOp, updateToVars,
2394 updateFromVars)))
2395 return failure();
2396 } else if (!isa<DeclareMapperInfoOp>(op)) {
2397 return emitError(op->getLoc(),
2398 "map argument is not a map entry operation");
2399 }
2400 }
2401
2402 // Verify iterated map entries.
2403 for (auto iterVal : mapIterated) {
2404 auto iterOp = iterVal.getDefiningOp<mlir::omp::IteratorOp>();
2405 if (!iterOp)
2406 return op->emitOpError() << "'map_iterated' arguments must be defined by "
2407 "'omp.iterator' ops";
2408
2409 // Check that the iterator body yields a value defined by omp.map.info.
2410 auto yieldOp =
2411 cast<mlir::omp::YieldOp>(iterOp.getRegion().front().getTerminator());
2412 auto yieldedMapInfo =
2413 yieldOp.getResults()[0].getDefiningOp<mlir::omp::MapInfoOp>();
2414 if (!yieldedMapInfo)
2415 return op->emitOpError() << "'map_iterated' iterator body must yield "
2416 "a value defined by 'omp.map.info'";
2417
2418 if (failed(verifyMapInfoForMapClause(op, yieldedMapInfo, updateToVars,
2419 updateFromVars)))
2420 return failure();
2421 }
2422
2423 return success();
2424}
2425
2426template <typename OpType>
2427static LogicalResult verifyPrivateVarList(OpType &op);
2428
2429static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp) {
2430 std::optional<DenseI64ArrayAttr> privateMapIndices =
2431 targetOp.getPrivateMapsAttr();
2432
2433 // None of the private operands are mapped.
2434 if (!privateMapIndices.has_value() || !privateMapIndices.value())
2435 return success();
2436
2437 OperandRange privateVars = targetOp.getPrivateVars();
2438
2439 if (privateMapIndices.value().size() !=
2440 static_cast<int64_t>(privateVars.size()))
2441 return emitError(targetOp.getLoc(), "sizes of `private` operand range and "
2442 "`private_maps` attribute mismatch");
2443
2444 return success();
2445}
2446
2447//===----------------------------------------------------------------------===//
2448// MapInfoOp
2449//===----------------------------------------------------------------------===//
2450
2451static LogicalResult verifyMapInfoDefinedArgs(Operation *op,
2452 StringRef clauseName,
2453 OperandRange vars) {
2454 for (Value var : vars)
2455 if (!llvm::isa_and_present<MapInfoOp>(var.getDefiningOp()))
2456 return op->emitOpError()
2457 << "'" << clauseName
2458 << "' arguments must be defined by 'omp.map.info' ops";
2459 return success();
2460}
2461
2462LogicalResult MapInfoOp::verify() {
2463 if (getMapperId() &&
2465 *this, getMapperIdAttr())) {
2466 return emitError("invalid mapper id");
2467 }
2468
2469 if (failed(verifyMapInfoDefinedArgs(*this, "members", getMembers())))
2470 return failure();
2471
2472 return success();
2473}
2474
2475//===----------------------------------------------------------------------===//
2476// TargetDataOp
2477//===----------------------------------------------------------------------===//
2478
2479void TargetDataOp::build(OpBuilder &builder, OperationState &state,
2480 const TargetDataOperands &clauses) {
2481 TargetDataOp::build(builder, state, clauses.device, clauses.ifExpr,
2482 clauses.mapVars, clauses.mapIterated,
2483 clauses.useDeviceAddrVars, clauses.useDevicePtrVars);
2484}
2485
2486LogicalResult TargetDataOp::verify() {
2487 if (getMapVars().empty() && getMapIterated().empty() &&
2488 getUseDevicePtrVars().empty() && getUseDeviceAddrVars().empty()) {
2489 return ::emitError(this->getLoc(),
2490 "At least one of map, use_device_ptr_vars, or "
2491 "use_device_addr_vars operand must be present");
2492 }
2493
2494 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_ptr",
2495 getUseDevicePtrVars())))
2496 return failure();
2497
2498 if (failed(verifyMapInfoDefinedArgs(*this, "use_device_addr",
2499 getUseDeviceAddrVars())))
2500 return failure();
2501
2502 return verifyMapClause(*this, getMapVars(), getMapIterated());
2503}
2504
2505//===----------------------------------------------------------------------===//
2506// TargetEnterDataOp
2507//===----------------------------------------------------------------------===//
2508
2509void TargetEnterDataOp::build(
2510 OpBuilder &builder, OperationState &state,
2511 const TargetEnterExitUpdateDataOperands &clauses) {
2512 MLIRContext *ctx = builder.getContext();
2513 TargetEnterDataOp::build(
2514 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2515 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2516 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2517 clauses.mapIterated, clauses.nowait);
2518}
2519
2520LogicalResult TargetEnterDataOp::verify() {
2521 LogicalResult verifyDependVars =
2522 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2523 getDependIteratedKinds(), getDependIterated());
2524 return failed(verifyDependVars)
2525 ? verifyDependVars
2526 : verifyMapClause(*this, getMapVars(), getMapIterated());
2527}
2528
2529//===----------------------------------------------------------------------===//
2530// TargetExitDataOp
2531//===----------------------------------------------------------------------===//
2532
2533void TargetExitDataOp::build(OpBuilder &builder, OperationState &state,
2534 const TargetEnterExitUpdateDataOperands &clauses) {
2535 MLIRContext *ctx = builder.getContext();
2536 TargetExitDataOp::build(
2537 builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2538 clauses.dependVars, makeArrayAttr(ctx, clauses.dependIteratedKinds),
2539 clauses.dependIterated, clauses.device, clauses.ifExpr, clauses.mapVars,
2540 clauses.mapIterated, clauses.nowait);
2541}
2542
2543LogicalResult TargetExitDataOp::verify() {
2544 LogicalResult verifyDependVars =
2545 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2546 getDependIteratedKinds(), getDependIterated());
2547 return failed(verifyDependVars)
2548 ? verifyDependVars
2549 : verifyMapClause(*this, getMapVars(), getMapIterated());
2550}
2551
2552//===----------------------------------------------------------------------===//
2553// TargetUpdateOp
2554//===----------------------------------------------------------------------===//
2555
2556void TargetUpdateOp::build(OpBuilder &builder, OperationState &state,
2557 const TargetEnterExitUpdateDataOperands &clauses) {
2558 MLIRContext *ctx = builder.getContext();
2559 TargetUpdateOp::build(builder, state, makeArrayAttr(ctx, clauses.dependKinds),
2560 clauses.dependVars,
2561 makeArrayAttr(ctx, clauses.dependIteratedKinds),
2562 clauses.dependIterated, clauses.device, clauses.ifExpr,
2563 clauses.mapVars, clauses.mapIterated, clauses.nowait);
2564}
2565
2566LogicalResult TargetUpdateOp::verify() {
2567 LogicalResult verifyDependVars =
2568 verifyDependVarList(*this, getDependKinds(), getDependVars(),
2569 getDependIteratedKinds(), getDependIterated());
2570 return failed(verifyDependVars)
2571 ? verifyDependVars
2572 : verifyMapClause(*this, getMapVars(), getMapIterated());
2573}
2574
2575//===----------------------------------------------------------------------===//
2576// TargetOp
2577//===----------------------------------------------------------------------===//
2578
2579void TargetOp::build(OpBuilder &builder, OperationState &state,
2580 const TargetOperands &clauses) {
2581 MLIRContext *ctx = builder.getContext();
2582 // TODO Store clauses in op: allocateVars, allocatorVars, inReductionVars,
2583 // inReductionByref, inReductionSyms.
2584 TargetOp::build(
2585 builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.bare,
2586 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
2587 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
2588 clauses.device, clauses.dynGroupprivateAccessGroup,
2589 clauses.dynGroupprivateFallback, clauses.dynGroupprivateSize,
2590 clauses.hasDeviceAddrVars, clauses.hostEvalVars, clauses.ifExpr,
2591 /*in_reduction_vars=*/{}, /*in_reduction_byref=*/nullptr,
2592 /*in_reduction_syms=*/nullptr, clauses.isDevicePtrVars, clauses.mapVars,
2593 clauses.mapIterated, clauses.nowait, clauses.privateVars,
2594 makeArrayAttr(ctx, clauses.privateSyms), clauses.privateNeedsBarrier,
2595 clauses.threadLimitVars,
2596 /*private_maps=*/nullptr);
2597}
2598
2599LogicalResult TargetOp::verify() {
2600 if (failed(verifyDependVarList(*this, getDependKinds(), getDependVars(),
2601 getDependIteratedKinds(),
2602 getDependIterated())))
2603 return failure();
2604
2605 if (failed(verifyMapInfoDefinedArgs(*this, "has_device_addr",
2606 getHasDeviceAddrVars())))
2607 return failure();
2608
2609 if (failed(verifyMapClause(*this, getMapVars(), getMapIterated())))
2610 return failure();
2611
2613 *this, getDynGroupprivateAccessGroupAttr(),
2614 getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
2615 return failure();
2616
2617 if (failed(verifyPrivateVarList(*this)))
2618 return failure();
2619
2620 return verifyPrivateVarsMapping(*this);
2621}
2622
2623LogicalResult TargetOp::verifyRegions() {
2624 auto teamsOps = getOps<TeamsOp>();
2625 if (std::distance(teamsOps.begin(), teamsOps.end()) > 1)
2626 return emitError("target containing multiple 'omp.teams' nested ops");
2627
2628 // Check that host_eval values are only used in legal ways.
2629 bool hostEvalTripCount;
2630 Operation *capturedOp = getInnermostCapturedOmpOp();
2631 TargetExecMode execMode = getKernelExecFlags(capturedOp, &hostEvalTripCount);
2632 for (Value hostEvalArg :
2633 cast<BlockArgOpenMPOpInterface>(getOperation()).getHostEvalBlockArgs()) {
2634 for (Operation *user : hostEvalArg.getUsers()) {
2635 if (auto teamsOp = dyn_cast<TeamsOp>(user)) {
2636 // Check if used in num_teams_lower or any of num_teams_upper_vars
2637 if (hostEvalArg == teamsOp.getNumTeamsLower() ||
2638 llvm::is_contained(teamsOp.getNumTeamsUpperVars(), hostEvalArg) ||
2639 llvm::is_contained(teamsOp.getThreadLimitVars(), hostEvalArg))
2640 continue;
2641
2642 return emitOpError() << "host_eval argument only legal as 'num_teams' "
2643 "and 'thread_limit' in 'omp.teams'";
2644 }
2645 if (auto parallelOp = dyn_cast<ParallelOp>(user)) {
2646 if (execMode == TargetExecMode::spmd &&
2647 parallelOp->isAncestor(capturedOp) &&
2648 llvm::is_contained(parallelOp.getNumThreadsVars(), hostEvalArg))
2649 continue;
2650
2651 return emitOpError()
2652 << "host_eval argument only legal as 'num_threads' in "
2653 "'omp.parallel' when representing target SPMD";
2654 }
2655 if (auto loopNestOp = dyn_cast<LoopNestOp>(user)) {
2656 if (hostEvalTripCount && loopNestOp.getOperation() == capturedOp &&
2657 (llvm::is_contained(loopNestOp.getLoopLowerBounds(), hostEvalArg) ||
2658 llvm::is_contained(loopNestOp.getLoopUpperBounds(), hostEvalArg) ||
2659 llvm::is_contained(loopNestOp.getLoopSteps(), hostEvalArg)))
2660 continue;
2661
2662 return emitOpError() << "host_eval argument only legal as loop bounds "
2663 "and steps in 'omp.loop_nest' when trip count "
2664 "must be evaluated in the host";
2665 }
2666
2667 return emitOpError() << "host_eval argument illegal use in '"
2668 << user->getName() << "' operation";
2669 }
2670 }
2671 return success();
2672}
2673
/// Walks the operations nested under `rootOp` in pre-order. Returns the most
/// recently captured operation, or null if nothing was captured.
///
/// Candidates are operations from the same dialect as `rootOp` that have
/// regions. Any other operation is skipped, and its nested operations are not
/// visited. A candidate is captured, and the walk then continues into its
/// regions, unless one of the following holds. If one does, the whole walk
/// stops:
///  - `checkSingleMandatoryExec` is set and either the candidate's block is
///    reachable from one of its own successors, or some block of the parent
///    region is reachable from the entry, has no successors, and is not
///    dominated by the candidate's block;
///  - another operation in the candidate's parent region is rejected by
///    `siblingAllowedFn`.
/// The walk also stops right after capturing an `omp.loop_nest`.
static Operation *
findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec,
                  llvm::function_ref<bool(Operation *)> siblingAllowedFn) {
  assert(rootOp && "expected valid operation");

  // `rootOp`'s dialect is treated as the OpenMP dialect for candidate checks.
  Dialect *ompDialect = rootOp->getDialect();
  // Last operation accepted by the walk; this is the function's result.
  Operation *capturedOp = nullptr;
  DominanceInfo domInfo;

  // Process in pre-order to check operations from outermost to innermost,
  // ensuring we only enter the region of an operation if it meets the criteria
  // for being captured. We stop the exploration of nested operations as soon as
  // we process a region holding no operations to be captured.
  rootOp->walk<WalkOrder::PreOrder>([&](Operation *op) {
    if (op == rootOp)
      return WalkResult::advance();

    // Ignore operations of other dialects or omp operations with no regions,
    // because these will only be checked if they are siblings of an omp
    // operation that can potentially be captured.
    bool isOmpDialect = op->getDialect() == ompDialect;
    bool hasRegions = op->getNumRegions() > 0;
    if (!isOmpDialect || !hasRegions)
      return WalkResult::skip();

    // This operation cannot be captured if it can be executed more than once
    // (i.e. its block's successors can reach it) or if it's not guaranteed to
    // be executed before all exits of the region (i.e. it doesn't dominate all
    // blocks with no successors reachable from the entry block).
    if (checkSingleMandatoryExec) {
      Region *parentRegion = op->getParentRegion();
      Block *parentBlock = op->getBlock();

      for (Block *successor : parentBlock->getSuccessors())
        if (successor->isReachable(parentBlock))
          return WalkResult::interrupt();

      for (Block &block : *parentRegion)
        if (domInfo.isReachableFromEntry(&block) && block.hasNoSuccessors() &&
            !domInfo.dominates(parentBlock, &block))
          return WalkResult::interrupt();
    }

    // Don't capture this op if it has a not-allowed sibling, and stop recursing
    // into nested operations.
    for (Operation &sibling : op->getParentRegion()->getOps())
      if (&sibling != op && !siblingAllowedFn(&sibling))
        return WalkResult::interrupt();

    // Don't continue capturing nested operations if we reach an omp.loop_nest.
    // Otherwise, process the contents of this operation.
    capturedOp = op;
    return llvm::isa<LoopNestOp>(op) ? WalkResult::interrupt()
                                     : WalkResult::advance();
  });

  return capturedOp;
}
2732
/// Return the innermost OpenMP operation captured by this omp.target, or
/// nullptr if nothing can be captured.
///
/// Delegates to findCapturedOmpOp with single-mandatory-execution checks
/// enabled, so a captured operation must not be re-executable and must
/// dominate every reachable exit of its parent region.
Operation *TargetOp::getInnermostCapturedOmpOp() {
  auto *ompDialect = getContext()->getLoadedDialect<omp::OpenMPDialect>();

  // Sibling filter: OpenMP-dialect siblings are allowed only if they are
  // terminators. Non-OpenMP siblings are rejected only when they implement
  // MemoryEffectOpInterface and report a Write to an
  // AutomaticAllocationScopeResource; siblings without that interface are
  // accepted.
  // NOTE(review): this is more permissive than "known effects with no write":
  // ops with unknown effects, and writes to other resources, are allowed.
  // Confirm this is intended.
  return findCapturedOmpOp(
      *this, /*checkSingleMandatoryExec=*/true, [&](Operation *sibling) {
        if (!sibling)
          return false;

        if (ompDialect == sibling->getDialect())
          return sibling->hasTrait<OpTrait::IsTerminator>();

        if (auto memOp = dyn_cast<MemoryEffectOpInterface>(sibling)) {
              effects;
          memOp.getEffects(effects);
          return !llvm::any_of(
              effects, [&](MemoryEffects::EffectInstance &effect) {
                return isa<MemoryEffects::Write>(effect.getEffect()) &&
                       isa<SideEffects::AutomaticAllocationScopeResource>(
                           effect.getResource());
              });
        }
        return true;
      });
}
2760
2761/// Check if we can promote SPMD kernel to No-Loop kernel.
2762static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp,
2763 WsloopOp *wsLoopOp) {
2764 // num_teams clause can break no-loop teams/threads assumption.
2765 if (!teamsOp.getNumTeamsUpperVars().empty())
2766 return false;
2767
2768 // Reduction kernels are slower in no-loop mode.
2769 if (teamsOp.getNumReductionVars())
2770 return false;
2771 if (wsLoopOp->getNumReductionVars())
2772 return false;
2773
2774 // Check if the user allows the promotion of kernels to no-loop mode.
2775 OffloadModuleInterface offloadMod =
2776 capturedOp->getParentOfType<omp::OffloadModuleInterface>();
2777 if (!offloadMod)
2778 return false;
2779 auto ompFlags = offloadMod.getFlags();
2780 if (!ompFlags)
2781 return false;
2782 return ompFlags.getAssumeTeamsOversubscription() &&
2783 ompFlags.getAssumeThreadsOversubscription();
2784}
2785
/// Determine the kernel execution mode of a target region from its innermost
/// captured OpenMP operation.
///
/// \param capturedOp Either nullptr or the result of calling
///        getInnermostCapturedOmpOp() on its enclosing omp.target (asserted).
/// \param hostEvalTripCount Optional out-parameter. It is reset to false, then
///        set to true when the recognized construct (teams-distribute-
///        parallel-wsloop, teams-distribute or teams-loop placed directly in
///        the omp.target) has its loop trip count evaluated on the host.
/// \returns `spmd` (or `no_loop` when promotable) for
///          target-teams-distribute-parallel-wsloop[-simd], `spmd` for
///          target-teams-loop and target-parallel-wsloop[-simd], and `generic`
///          for everything else, including target-teams-distribute[-simd].
TargetExecMode TargetOp::getKernelExecFlags(Operation *capturedOp,
                                            bool *hostEvalTripCount) {
  // TODO: Support detection of bare kernel mode.
  // A non-null captured op is only valid if it resides inside of a TargetOp
  // and is the result of calling getInnermostCapturedOmpOp() on it.
  TargetOp targetOp =
      capturedOp ? capturedOp->getParentOfType<TargetOp>() : nullptr;
  assert((!capturedOp ||
          (targetOp && targetOp.getInnermostCapturedOmpOp() == capturedOp)) &&
         "unexpected captured op");

  if (hostEvalTripCount)
    *hostEvalTripCount = false;

  // If it's not capturing a loop, it's a default target region.
  if (!isa_and_present<LoopNestOp>(capturedOp))
    return TargetExecMode::generic;

  // Get the innermost non-simd loop wrapper.
  cast<LoopNestOp>(capturedOp).gatherWrappers(loopWrappers);
  assert(!loopWrappers.empty());

  // NOTE(review): assumes gatherWrappers() stores the innermost wrapper first,
  // which is what the simd check below relies on — confirm.
  LoopWrapperInterface *innermostWrapper = loopWrappers.begin();
  if (isa<SimdOp>(innermostWrapper))
    innermostWrapper = std::next(innermostWrapper);

  // Count the wrappers outside the optional simd; only 1 or 2 can match a
  // recognized construct.
  auto numWrappers = std::distance(innermostWrapper, loopWrappers.end());
  if (numWrappers != 1 && numWrappers != 2)
    return TargetExecMode::generic;

  // Detect target-teams-distribute-parallel-wsloop[-simd].
  if (numWrappers == 2) {
    WsloopOp *wsloopOp = dyn_cast<WsloopOp>(innermostWrapper);
    if (!wsloopOp)
      return TargetExecMode::generic;

    innermostWrapper = std::next(innermostWrapper);
    if (!isa<DistributeOp>(innermostWrapper))
      return TargetExecMode::generic;

    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    TeamsOp teamsOp = dyn_cast<TeamsOp>(parallelOp->getParentOp());
    if (!teamsOp)
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() == targetOp.getOperation()) {
      TargetExecMode result = TargetExecMode::spmd;
      if (canPromoteToNoLoop(capturedOp, teamsOp, wsloopOp))
        result = TargetExecMode::no_loop;
      if (hostEvalTripCount)
        *hostEvalTripCount = true;
      return result;
    }
  }
  // Detect target-teams-distribute[-simd] and target-teams-loop.
  else if (isa<DistributeOp, LoopOp>(innermostWrapper)) {
    Operation *teamsOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<TeamsOp>(teamsOp))
      return TargetExecMode::generic;

    if (teamsOp->getParentOp() != targetOp.getOperation())
      return TargetExecMode::generic;

    if (hostEvalTripCount)
      *hostEvalTripCount = true;

    if (isa<LoopOp>(innermostWrapper))
      return TargetExecMode::spmd;

    // target-teams-distribute stays generic even though its trip count is
    // evaluated on the host.
    return TargetExecMode::generic;
  }
  // Detect target-parallel-wsloop[-simd].
  else if (isa<WsloopOp>(innermostWrapper)) {
    Operation *parallelOp = (*innermostWrapper)->getParentOp();
    if (!isa_and_present<ParallelOp>(parallelOp))
      return TargetExecMode::generic;

    if (parallelOp->getParentOp() == targetOp.getOperation())
      return TargetExecMode::spmd;
  }

  return TargetExecMode::generic;
}
2873
2874//===----------------------------------------------------------------------===//
2875// ParallelOp
2876//===----------------------------------------------------------------------===//
2877
2878void ParallelOp::build(OpBuilder &builder, OperationState &state,
2879 ArrayRef<NamedAttribute> attributes) {
2880 ParallelOp::build(builder, state, /*allocate_vars=*/ValueRange(),
2881 /*allocator_vars=*/ValueRange(), /*if_expr=*/nullptr,
2882 /*num_threads_vars=*/ValueRange(),
2883 /*private_vars=*/ValueRange(),
2884 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
2885 /*proc_bind_kind=*/nullptr,
2886 /*reduction_mod =*/nullptr, /*reduction_vars=*/ValueRange(),
2887 /*reduction_byref=*/nullptr, /*reduction_syms=*/nullptr);
2888 state.addAttributes(attributes);
2889}
2890
2891void ParallelOp::build(OpBuilder &builder, OperationState &state,
2892 const ParallelOperands &clauses) {
2893 MLIRContext *ctx = builder.getContext();
2894 ParallelOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
2895 clauses.ifExpr, clauses.numThreadsVars, clauses.privateVars,
2896 makeArrayAttr(ctx, clauses.privateSyms),
2897 clauses.privateNeedsBarrier, clauses.procBindKind,
2898 clauses.reductionMod, clauses.reductionVars,
2899 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
2900 makeArrayAttr(ctx, clauses.reductionSyms));
2901}
2902
/// Verify the private clause of `op`. The checks are:
///  - the number of private variables matches the number of privatizer
///    symbols;
///  - every symbol resolves to a PrivateClauseOp;
///  - each variable's type equals the privatizer's argument type, whenever
///    the privatizer declares one.
/// \returns success() when the op has no private clause at all.
template <typename OpType>
static LogicalResult verifyPrivateVarList(OpType &op) {
  auto privateVars = op.getPrivateVars();
  auto privateSyms = op.getPrivateSymsAttr();

  // No private variables and no (or an empty) symbol list: nothing to check.
  if (privateVars.empty() && (privateSyms == nullptr || privateSyms.empty()))
    return success();

  auto numPrivateVars = privateVars.size();
  auto numPrivateSyms = (privateSyms == nullptr) ? 0 : privateSyms.size();

  if (numPrivateVars != numPrivateSyms)
    return op.emitError() << "inconsistent number of private variables and "
                             "privatizer op symbols, private vars: "
                          << numPrivateVars
                          << " vs. privatizer op symbols: " << numPrivateSyms;

  // Lengths are equal at this point, so zip_equal pairs every variable with
  // its privatizer symbol.
  for (auto privateVarInfo : llvm::zip_equal(privateVars, privateSyms)) {
    Type varType = std::get<0>(privateVarInfo).getType();
    SymbolRefAttr privateSym = cast<SymbolRefAttr>(std::get<1>(privateVarInfo));
    PrivateClauseOp privatizerOp =

    if (privatizerOp == nullptr)
      return op.emitError() << "failed to lookup privatizer op with symbol: '"
                            << privateSym << "'";

    Type privatizerType = privatizerOp.getArgType();

    // A null argument type on the privatizer skips the type comparison.
    if (privatizerType && (varType != privatizerType))
      return op.emitError()
             << "type mismatch between a "
             << (privatizerOp.getDataSharingType() ==
                         DataSharingClauseType::Private
                     ? "private"
                     : "firstprivate")
             << " variable and its privatizer op, var type: " << varType
             << " vs. privatizer op type: " << privatizerType;
  }

  return success();
}
2945
2946LogicalResult ParallelOp::verify() {
2947 if (getAllocateVars().size() != getAllocatorVars().size())
2948 return emitError(
2949 "expected equal sizes for allocate and allocator variables");
2950
2951 if (failed(verifyPrivateVarList(*this)))
2952 return failure();
2953
2954 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
2955 getReductionByref());
2956}
2957
2958LogicalResult ParallelOp::verifyRegions() {
2959 auto distChildOps = getOps<DistributeOp>();
2960 int numDistChildOps = std::distance(distChildOps.begin(), distChildOps.end());
2961 if (numDistChildOps > 1)
2962 return emitError()
2963 << "multiple 'omp.distribute' nested inside of 'omp.parallel'";
2964
2965 if (numDistChildOps == 1) {
2966 if (!isComposite())
2967 return emitError()
2968 << "'omp.composite' attribute missing from composite operation";
2969
2970 auto *ompDialect = getContext()->getLoadedDialect<OpenMPDialect>();
2971 Operation &distributeOp = **distChildOps.begin();
2972 for (Operation &childOp : getOps()) {
2973 if (&childOp == &distributeOp || ompDialect != childOp.getDialect())
2974 continue;
2975
2976 if (!childOp.hasTrait<OpTrait::IsTerminator>())
2977 return emitError() << "unexpected OpenMP operation inside of composite "
2978 "'omp.parallel': "
2979 << childOp.getName();
2980 }
2981 } else if (isComposite()) {
2982 return emitError()
2983 << "'omp.composite' attribute present in non-composite operation";
2984 }
2985 return success();
2986}
2987
2988//===----------------------------------------------------------------------===//
2989// TeamsOp
2990//===----------------------------------------------------------------------===//
2991
  // Climb the parent chain: any OpenMP-dialect ancestor means `op` is nested
  // inside an OpenMP operation, so the answer is false.
  while ((op = op->getParentOp()))
    if (isa<OpenMPDialect>(op->getDialect()))
      return false;
  return true;
}
2998
2999void TeamsOp::build(OpBuilder &builder, OperationState &state,
3000 const TeamsOperands &clauses) {
3001 MLIRContext *ctx = builder.getContext();
3002 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3003 TeamsOp::build(
3004 builder, state, clauses.allocateVars, clauses.allocatorVars,
3005 clauses.dynGroupprivateAccessGroup, clauses.dynGroupprivateFallback,
3006 clauses.dynGroupprivateSize, clauses.ifExpr, clauses.numTeamsLower,
3007 clauses.numTeamsUpperVars, /*private_vars=*/{}, /*private_syms=*/nullptr,
3008 /*private_needs_barrier=*/nullptr, clauses.reductionMod,
3009 clauses.reductionVars,
3010 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3011 makeArrayAttr(ctx, clauses.reductionSyms), clauses.threadLimitVars);
3012}
3013
3014// Verify num_teams clause
3015static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower,
3016 OperandRange numTeamsUpperVars) {
3017 // If lower is specified, upper must have exactly one value
3018 if (numTeamsLower) {
3019 if (numTeamsUpperVars.size() != 1)
3020 return op->emitError(
3021 "expected exactly one num_teams upper bound when lower bound is "
3022 "specified");
3023 if (numTeamsLower.getType() != numTeamsUpperVars[0].getType())
3024 return op->emitError(
3025 "expected num_teams upper bound and lower bound to be "
3026 "the same type");
3027 }
3028
3029 return success();
3030}
3031
/// Verify omp.teams. Its placement is checked first: per the diagnostic, it
/// must be directly inside omp.target or not nested in any OpenMP operation.
/// Then the num_teams, allocate, dyn_groupprivate, private and reduction
/// clauses are checked, in that order.
LogicalResult TeamsOp::verify() {
  // Check parent region
  // TODO If nested inside of a target region, also check that it does not
  // contain any statements, declarations or directives other than this
  // omp.teams construct. The issue is how to support the initialization of
  // this operation's own arguments (allow SSA values across omp.target?).
  Operation *op = getOperation();
  if (!isa<TargetOp>(op->getParentOp()) &&
    return emitError("expected to be nested inside of omp.target or not nested "
                     "in any OpenMP dialect operations");

  // Check for num_teams clause restrictions
  if (failed(verifyNumTeamsClause(op, this->getNumTeamsLower(),
                                  this->getNumTeamsUpperVars())))
    return failure();

  // Check for allocate clause restrictions
  if (getAllocateVars().size() != getAllocatorVars().size())
    return emitError(
        "expected equal sizes for allocate and allocator variables");

          op, getDynGroupprivateAccessGroupAttr(),
          getDynGroupprivateFallbackAttr(), getDynGroupprivateSize())))
    return failure();

  if (failed(verifyPrivateVarList(*this)))
    return failure();

  return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
                                getReductionByref());
}
3065
3066//===----------------------------------------------------------------------===//
3067// SectionOp
3068//===----------------------------------------------------------------------===//
3069
3070OperandRange SectionOp::getPrivateVars() {
3071 return getParentOp().getPrivateVars();
3072}
3073
3074OperandRange SectionOp::getReductionVars() {
3075 return getParentOp().getReductionVars();
3076}
3077
3078//===----------------------------------------------------------------------===//
3079// SectionsOp
3080//===----------------------------------------------------------------------===//
3081
3082void SectionsOp::build(OpBuilder &builder, OperationState &state,
3083 const SectionsOperands &clauses) {
3084 MLIRContext *ctx = builder.getContext();
3085 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3086 SectionsOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3087 clauses.nowait, /*private_vars=*/{},
3088 /*private_syms=*/nullptr, /*private_needs_barrier=*/nullptr,
3089 clauses.reductionMod, clauses.reductionVars,
3090 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3091 makeArrayAttr(ctx, clauses.reductionSyms));
3092}
3093
3094LogicalResult SectionsOp::verify() {
3095 if (getAllocateVars().size() != getAllocatorVars().size())
3096 return emitError(
3097 "expected equal sizes for allocate and allocator variables");
3098
3099 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3100 getReductionByref());
3101}
3102
3103LogicalResult SectionsOp::verifyRegions() {
3104 for (auto &inst : *getRegion().begin()) {
3105 if (!(isa<SectionOp>(inst) || isa<TerminatorOp>(inst))) {
3106 return emitOpError()
3107 << "expected omp.section op or terminator op inside region";
3108 }
3109 }
3110
3111 return success();
3112}
3113
3114//===----------------------------------------------------------------------===//
3115// ScopeOp
3116//===----------------------------------------------------------------------===//
3117
3118void ScopeOp::build(OpBuilder &builder, OperationState &state,
3119 const ScopeOperands &clauses) {
3120 MLIRContext *ctx = builder.getContext();
3121 ScopeOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3122 clauses.nowait, clauses.privateVars,
3123 makeArrayAttr(ctx, clauses.privateSyms),
3124 clauses.privateNeedsBarrier, clauses.reductionMod,
3125 clauses.reductionVars,
3126 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3127 makeArrayAttr(ctx, clauses.reductionSyms));
3128}
3129
3130LogicalResult ScopeOp::verify() {
3131 if (getAllocateVars().size() != getAllocatorVars().size())
3132 return emitError(
3133 "expected equal sizes for allocate and allocator variables");
3134
3135 if (failed(verifyPrivateVarList(*this)))
3136 return failure();
3137
3138 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3139 getReductionByref());
3140}
3141
3142//===----------------------------------------------------------------------===//
3143// SingleOp
3144//===----------------------------------------------------------------------===//
3145
3146void SingleOp::build(OpBuilder &builder, OperationState &state,
3147 const SingleOperands &clauses) {
3148 MLIRContext *ctx = builder.getContext();
3149 // TODO Store clauses in op: privateVars, privateSyms, privateNeedsBarrier
3150 SingleOp::build(builder, state, clauses.allocateVars, clauses.allocatorVars,
3151 clauses.copyprivateVars,
3152 makeArrayAttr(ctx, clauses.copyprivateSyms), clauses.nowait,
3153 /*private_vars=*/{}, /*private_syms=*/nullptr,
3154 /*private_needs_barrier=*/nullptr);
3155}
3156
3157LogicalResult SingleOp::verify() {
3158 // Check for allocate clause restrictions
3159 if (getAllocateVars().size() != getAllocatorVars().size())
3160 return emitError(
3161 "expected equal sizes for allocate and allocator variables");
3162
3163 return verifyCopyprivateVarList(*this, getCopyprivateVars(),
3164 getCopyprivateSyms());
3165}
3166
3167//===----------------------------------------------------------------------===//
3168// WorkshareOp
3169//===----------------------------------------------------------------------===//
3170
3171void WorkshareOp::build(OpBuilder &builder, OperationState &state,
3172 const WorkshareOperands &clauses) {
3173 WorkshareOp::build(builder, state, clauses.nowait);
3174}
3175
3176//===----------------------------------------------------------------------===//
3177// WorkshareLoopWrapperOp
3178//===----------------------------------------------------------------------===//
3179
3180LogicalResult WorkshareLoopWrapperOp::verifyRegions() {
3181 if (isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3182 getNestedWrapper())
3183 return emitOpError() << "expected to be a standalone loop wrapper";
3184
3185 return success();
3186}
3187
3188//===----------------------------------------------------------------------===//
3189// LoopWrapperInterface
3190//===----------------------------------------------------------------------===//
3191
/// Verify the structural contract shared by every loop wrapper operation:
///  - it carries the `NoTerminator` and `SingleBlock` traits (as stated by
///    the diagnostic);
///  - it has exactly one region, holding exactly one operation;
///  - that operation is another loop wrapper or an `omp.loop_nest`.
LogicalResult LoopWrapperInterface::verifyImpl() {
  Operation *op = this->getOperation();
  if (!op->hasTrait<OpTrait::NoTerminator>() ||
    return emitOpError() << "loop wrapper must also have the `NoTerminator` "
                            "and `SingleBlock` traits";

  if (op->getNumRegions() != 1)
    return emitOpError() << "loop wrapper does not contain exactly one region";

  Region &region = op->getRegion(0);
  if (range_size(region.getOps()) != 1)
    return emitOpError()
           << "loop wrapper does not contain exactly one nested op";

  // The single nested op continues the wrapper chain or ends it with the loop.
  Operation &firstOp = *region.op_begin();
  if (!isa<LoopNestOp, LoopWrapperInterface>(firstOp))
    return emitOpError() << "nested in loop wrapper is not another loop "
                            "wrapper or `omp.loop_nest`";

  return success();
}
3214
3215//===----------------------------------------------------------------------===//
3216// LoopOp
3217//===----------------------------------------------------------------------===//
3218
3219void LoopOp::build(OpBuilder &builder, OperationState &state,
3220 const LoopOperands &clauses) {
3221 MLIRContext *ctx = builder.getContext();
3222
3223 LoopOp::build(builder, state, clauses.bindKind, clauses.privateVars,
3224 makeArrayAttr(ctx, clauses.privateSyms),
3225 clauses.privateNeedsBarrier, clauses.order, clauses.orderMod,
3226 clauses.reductionMod, clauses.reductionVars,
3227 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3228 makeArrayAttr(ctx, clauses.reductionSyms));
3229}
3230
3231LogicalResult LoopOp::verify() {
3232 if (failed(verifyPrivateVarList(*this)))
3233 return failure();
3234
3235 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3236 getReductionByref());
3237}
3238
3239LogicalResult LoopOp::verifyRegions() {
3240 if (llvm::isa_and_nonnull<LoopWrapperInterface>((*this)->getParentOp()) ||
3241 getNestedWrapper())
3242 return emitOpError() << "expected to be a standalone loop wrapper";
3243
3244 return success();
3245}
3246
3247//===----------------------------------------------------------------------===//
3248// WsloopOp
3249//===----------------------------------------------------------------------===//
3250
3251void WsloopOp::build(OpBuilder &builder, OperationState &state,
3252 ArrayRef<NamedAttribute> attributes) {
3253 build(builder, state, /*allocate_vars=*/{}, /*allocator_vars=*/{},
3254 /*linear_vars=*/ValueRange(), /*linear_step_vars=*/ValueRange(),
3255 /*linear_var_types*/ nullptr, /*linear_modifiers=*/nullptr,
3256 /*nowait=*/false, /*order=*/nullptr, /*order_mod=*/nullptr,
3257 /*ordered=*/nullptr, /*private_vars=*/{}, /*private_syms=*/nullptr,
3258 /*private_needs_barrier=*/false,
3259 /*reduction_mod=*/nullptr, /*reduction_vars=*/ValueRange(),
3260 /*reduction_byref=*/nullptr,
3261 /*reduction_syms=*/nullptr, /*schedule_kind=*/nullptr,
3262 /*schedule_chunk=*/nullptr, /*schedule_mod=*/nullptr,
3263 /*schedule_simd=*/false);
3264 state.addAttributes(attributes);
3265}
3266
3267void WsloopOp::build(OpBuilder &builder, OperationState &state,
3268 const WsloopOperands &clauses) {
3269 MLIRContext *ctx = builder.getContext();
3270 // TODO: Store clauses in op: allocateVars, allocatorVars
3271 WsloopOp::build(
3272 builder, state,
3273 /*allocate_vars=*/{}, /*allocator_vars=*/{}, clauses.linearVars,
3274 clauses.linearStepVars, clauses.linearVarTypes, clauses.linearModifiers,
3275 clauses.nowait, clauses.order, clauses.orderMod, clauses.ordered,
3276 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3277 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3278 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3279 makeArrayAttr(ctx, clauses.reductionSyms), clauses.scheduleKind,
3280 clauses.scheduleChunk, clauses.scheduleMod, clauses.scheduleSimd);
3281}
3282
3283LogicalResult WsloopOp::verify() {
3284 if (failed(
3285 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3286 return failure();
3287 if (getLinearVars().size() &&
3288 getLinearVarTypes().value().size() != getLinearVars().size())
3289 return emitError() << "Ill-formed type attributes for linear variables";
3290
3291 if (failed(verifyPrivateVarList(*this)))
3292 return failure();
3293
3294 return verifyReductionVarList(*this, getReductionSyms(), getReductionVars(),
3295 getReductionByref());
3296}
3297
3298LogicalResult WsloopOp::verifyRegions() {
3299 bool isCompositeChildLeaf =
3300 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3301
3302 if (LoopWrapperInterface nested = getNestedWrapper()) {
3303 if (!isComposite())
3304 return emitError()
3305 << "'omp.composite' attribute missing from composite wrapper";
3306
3307 // Check for the allowed leaf constructs that may appear in a composite
3308 // construct directly after DO/FOR.
3309 if (!isa<SimdOp>(nested))
3310 return emitError() << "only supported nested wrapper is 'omp.simd'";
3311
3312 } else if (isComposite() && !isCompositeChildLeaf) {
3313 return emitError()
3314 << "'omp.composite' attribute present in non-composite wrapper";
3315 } else if (!isComposite() && isCompositeChildLeaf) {
3316 return emitError()
3317 << "'omp.composite' attribute missing from composite wrapper";
3318 }
3319
3320 return success();
3321}
3322
3323//===----------------------------------------------------------------------===//
3324// Simd construct [2.9.3.1]
3325//===----------------------------------------------------------------------===//
3326
3327void SimdOp::build(OpBuilder &builder, OperationState &state,
3328 const SimdOperands &clauses) {
3329 MLIRContext *ctx = builder.getContext();
3330 SimdOp::build(builder, state, clauses.alignedVars,
3331 makeArrayAttr(ctx, clauses.alignments), clauses.ifExpr,
3332 clauses.linearVars, clauses.linearStepVars,
3333 clauses.linearVarTypes, clauses.linearModifiers,
3334 clauses.nontemporalVars, clauses.order, clauses.orderMod,
3335 clauses.privateVars, makeArrayAttr(ctx, clauses.privateSyms),
3336 clauses.privateNeedsBarrier, clauses.reductionMod,
3337 clauses.reductionVars,
3338 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3339 makeArrayAttr(ctx, clauses.reductionSyms), clauses.safelen,
3340 clauses.simdlen);
3341}
3342
3343LogicalResult SimdOp::verify() {
3344 if (getSimdlen().has_value() && getSafelen().has_value() &&
3345 getSimdlen().value() > getSafelen().value())
3346 return emitOpError()
3347 << "simdlen clause and safelen clause are both present, but the "
3348 "simdlen value is not less than or equal to safelen value";
3349
3350 if (verifyAlignedClause(*this, getAlignments(), getAlignedVars()).failed())
3351 return failure();
3352
3353 if (verifyNontemporalClause(*this, getNontemporalVars()).failed())
3354 return failure();
3355
3356 if (failed(
3357 verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars())))
3358 return failure();
3359
3360 bool isCompositeChildLeaf =
3361 llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp());
3362
3363 if (!isComposite() && isCompositeChildLeaf)
3364 return emitError()
3365 << "'omp.composite' attribute missing from composite wrapper";
3366
3367 if (isComposite() && !isCompositeChildLeaf)
3368 return emitError()
3369 << "'omp.composite' attribute present in non-composite wrapper";
3370
3371 // Firstprivate is not allowed for SIMD in the standard. Check that none of
3372 // the private decls are for firstprivate.
3373 std::optional<ArrayAttr> privateSyms = getPrivateSyms();
3374 if (privateSyms) {
3375 for (const Attribute &sym : *privateSyms) {
3376 auto symRef = cast<SymbolRefAttr>(sym);
3377 omp::PrivateClauseOp privatizer =
3379 getOperation(), symRef);
3380 if (!privatizer)
3381 return emitError() << "Cannot find privatizer '" << symRef << "'";
3382 if (privatizer.getDataSharingType() ==
3383 DataSharingClauseType::FirstPrivate)
3384 return emitError() << "FIRSTPRIVATE cannot be used with SIMD";
3385 }
3386 }
3387
3388 if (failed(verifyPrivateVarList(*this)))
3389 return failure();
3390
3391 if (getLinearVars().size() &&
3392 getLinearVarTypes().value().size() != getLinearVars().size())
3393 return emitError() << "Ill-formed type attributes for linear variables";
3394 return success();
3395}
3396
3397LogicalResult SimdOp::verifyRegions() {
3398 if (getNestedWrapper())
3399 return emitOpError() << "must wrap an 'omp.loop_nest' directly";
3400
3401 return success();
3402}
3403
3404//===----------------------------------------------------------------------===//
3405// Distribute construct [2.9.4.1]
3406//===----------------------------------------------------------------------===//
3407
3408void DistributeOp::build(OpBuilder &builder, OperationState &state,
3409 const DistributeOperands &clauses) {
3410 DistributeOp::build(builder, state, clauses.allocateVars,
3411 clauses.allocatorVars, clauses.distScheduleStatic,
3412 clauses.distScheduleChunkSize, clauses.order,
3413 clauses.orderMod, clauses.privateVars,
3414 makeArrayAttr(builder.getContext(), clauses.privateSyms),
3415 clauses.privateNeedsBarrier);
3416}
3417
3418LogicalResult DistributeOp::verify() {
3419 if (this->getDistScheduleChunkSize() && !this->getDistScheduleStatic())
3420 return emitOpError() << "chunk size set without "
3421 "dist_schedule_static being present";
3422
3423 if (getAllocateVars().size() != getAllocatorVars().size())
3424 return emitError(
3425 "expected equal sizes for allocate and allocator variables");
3426
3427 if (failed(verifyPrivateVarList(*this)))
3428 return failure();
3429
3430 return success();
3431}
3432
3433LogicalResult DistributeOp::verifyRegions() {
3434 if (LoopWrapperInterface nested = getNestedWrapper()) {
3435 if (!isComposite())
3436 return emitError()
3437 << "'omp.composite' attribute missing from composite wrapper";
3438 // Check for the allowed leaf constructs that may appear in a composite
3439 // construct directly after DISTRIBUTE.
3440 if (isa<WsloopOp>(nested)) {
3441 Operation *parentOp = (*this)->getParentOp();
3442 if (!llvm::dyn_cast_if_present<ParallelOp>(parentOp) ||
3443 !cast<ComposableOpInterface>(parentOp).isComposite()) {
3444 return emitError() << "an 'omp.wsloop' nested wrapper is only allowed "
3445 "when a composite 'omp.parallel' is the direct "
3446 "parent";
3447 }
3448 } else if (!isa<SimdOp>(nested))
3449 return emitError() << "only supported nested wrappers are 'omp.simd' and "
3450 "'omp.wsloop'";
3451 } else if (isComposite()) {
3452 return emitError()
3453 << "'omp.composite' attribute present in non-composite wrapper";
3454 }
3455
3456 return success();
3457}
3458
3459//===----------------------------------------------------------------------===//
3460// DeclareMapperOp / DeclareMapperInfoOp
3461//===----------------------------------------------------------------------===//
3462
3463void DeclareMapperInfoOp::build(OpBuilder &builder, OperationState &state,
3464 const DeclareMapperInfoOperands &clauses) {
3465 DeclareMapperInfoOp::build(builder, state, clauses.mapVars,
3466 clauses.mapIterated);
3467}
3468
3469LogicalResult DeclareMapperInfoOp::verify() {
3470 return verifyMapClause(*this, getMapVars(), getMapIterated());
3471}
3472
3473LogicalResult DeclareMapperOp::verifyRegions() {
3474 if (!llvm::isa_and_present<DeclareMapperInfoOp>(
3475 getRegion().getBlocks().front().getTerminator()))
3476 return emitOpError() << "expected terminator to be a DeclareMapperInfoOp";
3477
3478 return success();
3479}
3480
3481//===----------------------------------------------------------------------===//
3482// DeclareReductionOp
3483//===----------------------------------------------------------------------===//
3484
3485LogicalResult DeclareReductionOp::verifyRegions() {
3486 if (!getAllocRegion().empty()) {
3487 for (YieldOp yieldOp : getAllocRegion().getOps<YieldOp>()) {
3488 if (yieldOp.getResults().size() != 1 ||
3489 yieldOp.getResults().getTypes()[0] != getType())
3490 return emitOpError() << "expects alloc region to yield a value "
3491 "of the reduction type";
3492 }
3493 }
3494
3495 if (getInitializerRegion().empty())
3496 return emitOpError() << "expects non-empty initializer region";
3497 Block &initializerEntryBlock = getInitializerRegion().front();
3498
3499 if (initializerEntryBlock.getNumArguments() == 1) {
3500 if (!getAllocRegion().empty())
3501 return emitOpError() << "expects two arguments to the initializer region "
3502 "when an allocation region is used";
3503 } else if (initializerEntryBlock.getNumArguments() == 2) {
3504 if (getAllocRegion().empty())
3505 return emitOpError() << "expects one argument to the initializer region "
3506 "when no allocation region is used";
3507 } else {
3508 return emitOpError()
3509 << "expects one or two arguments to the initializer region";
3510 }
3511
3512 for (mlir::Value arg : initializerEntryBlock.getArguments())
3513 if (arg.getType() != getType())
3514 return emitOpError() << "expects initializer region argument to match "
3515 "the reduction type";
3516
3517 for (YieldOp yieldOp : getInitializerRegion().getOps<YieldOp>()) {
3518 if (yieldOp.getResults().size() != 1 ||
3519 yieldOp.getResults().getTypes()[0] != getType())
3520 return emitOpError() << "expects initializer region to yield a value "
3521 "of the reduction type";
3522 }
3523
3524 if (getReductionRegion().empty())
3525 return emitOpError() << "expects non-empty reduction region";
3526 Block &reductionEntryBlock = getReductionRegion().front();
3527 if (reductionEntryBlock.getNumArguments() != 2 ||
3528 reductionEntryBlock.getArgumentTypes()[0] !=
3529 reductionEntryBlock.getArgumentTypes()[1] ||
3530 reductionEntryBlock.getArgumentTypes()[0] != getType())
3531 return emitOpError() << "expects reduction region with two arguments of "
3532 "the reduction type";
3533 for (YieldOp yieldOp : getReductionRegion().getOps<YieldOp>()) {
3534 if (yieldOp.getResults().size() != 1 ||
3535 yieldOp.getResults().getTypes()[0] != getType())
3536 return emitOpError() << "expects reduction region to yield a value "
3537 "of the reduction type";
3538 }
3539
3540 if (!getAtomicReductionRegion().empty()) {
3541 Block &atomicReductionEntryBlock = getAtomicReductionRegion().front();
3542 if (atomicReductionEntryBlock.getNumArguments() != 2 ||
3543 atomicReductionEntryBlock.getArgumentTypes()[0] !=
3544 atomicReductionEntryBlock.getArgumentTypes()[1])
3545 return emitOpError() << "expects atomic reduction region with two "
3546 "arguments of the same type";
3547 auto ptrType = llvm::dyn_cast<PointerLikeType>(
3548 atomicReductionEntryBlock.getArgumentTypes()[0]);
3549 if (!ptrType ||
3550 (ptrType.getElementType() && ptrType.getElementType() != getType()))
3551 return emitOpError() << "expects atomic reduction region arguments to "
3552 "be accumulators containing the reduction type";
3553 }
3554
3555 if (getCleanupRegion().empty())
3556 return success();
3557 Block &cleanupEntryBlock = getCleanupRegion().front();
3558 if (cleanupEntryBlock.getNumArguments() != 1 ||
3559 cleanupEntryBlock.getArgument(0).getType() != getType())
3560 return emitOpError() << "expects cleanup region with one argument "
3561 "of the reduction type";
3562
3563 return success();
3564}
3565
3566//===----------------------------------------------------------------------===//
3567// TaskOp
3568//===----------------------------------------------------------------------===//
3569
3570void TaskOp::build(OpBuilder &builder, OperationState &state,
3571 const TaskOperands &clauses) {
3572 MLIRContext *ctx = builder.getContext();
3573 TaskOp::build(
3574 builder, state, clauses.iterated, clauses.affinityVars,
3575 clauses.allocateVars, clauses.allocatorVars,
3576 makeArrayAttr(ctx, clauses.dependKinds), clauses.dependVars,
3577 makeArrayAttr(ctx, clauses.dependIteratedKinds), clauses.dependIterated,
3578 clauses.final, clauses.ifExpr, clauses.inReductionVars,
3579 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3580 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3581 clauses.priority, /*private_vars=*/clauses.privateVars,
3582 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3583 clauses.privateNeedsBarrier, clauses.untied, clauses.eventHandle);
3584}
3585
3586LogicalResult TaskOp::verify() {
3587 LogicalResult verifyDependVars =
3588 verifyDependVarList(*this, getDependKinds(), getDependVars(),
3589 getDependIteratedKinds(), getDependIterated());
3590 if (failed(verifyDependVars))
3591 return verifyDependVars;
3592
3593 if (failed(verifyPrivateVarList(*this)))
3594 return failure();
3595
3596 return verifyReductionVarList(*this, getInReductionSyms(),
3597 getInReductionVars(), getInReductionByref());
3598}
3599
3600//===----------------------------------------------------------------------===//
3601// TaskgroupOp
3602//===----------------------------------------------------------------------===//
3603
3604void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
3605 const TaskgroupOperands &clauses) {
3606 MLIRContext *ctx = builder.getContext();
3607 TaskgroupOp::build(builder, state, clauses.allocateVars,
3608 clauses.allocatorVars, clauses.taskReductionVars,
3609 makeDenseBoolArrayAttr(ctx, clauses.taskReductionByref),
3610 makeArrayAttr(ctx, clauses.taskReductionSyms));
3611}
3612
3613LogicalResult TaskgroupOp::verify() {
3614 return verifyReductionVarList(*this, getTaskReductionSyms(),
3615 getTaskReductionVars(),
3616 getTaskReductionByref());
3617}
3618
3619//===----------------------------------------------------------------------===//
3620// TaskloopContextOp
3621//===----------------------------------------------------------------------===//
3622
3623void TaskloopContextOp::build(OpBuilder &builder, OperationState &state,
3624 const TaskloopContextOperands &clauses) {
3625 MLIRContext *ctx = builder.getContext();
3626 TaskloopContextOp::build(
3627 builder, state, clauses.allocateVars, clauses.allocatorVars,
3628 clauses.final, clauses.grainsizeMod, clauses.grainsize, clauses.ifExpr,
3629 clauses.inReductionVars,
3630 makeDenseBoolArrayAttr(ctx, clauses.inReductionByref),
3631 makeArrayAttr(ctx, clauses.inReductionSyms), clauses.mergeable,
3632 clauses.nogroup, clauses.numTasksMod, clauses.numTasks, clauses.priority,
3633 /*private_vars=*/clauses.privateVars,
3634 /*private_syms=*/makeArrayAttr(ctx, clauses.privateSyms),
3635 clauses.privateNeedsBarrier, clauses.reductionMod, clauses.reductionVars,
3636 makeDenseBoolArrayAttr(ctx, clauses.reductionByref),
3637 makeArrayAttr(ctx, clauses.reductionSyms), clauses.untied);
3638}
3639
3640TaskloopWrapperOp TaskloopContextOp::getLoopOp() {
3641 return cast<TaskloopWrapperOp>(
3642 *llvm::find_if(getRegion().front(), [](mlir::Operation &op) {
3643 return isa<TaskloopWrapperOp>(op);
3644 }));
3645}
3646
3647LogicalResult TaskloopContextOp::verify() {
3648 if (getAllocateVars().size() != getAllocatorVars().size())
3649 return emitError(
3650 "expected equal sizes for allocate and allocator variables");
3651
3652 if (failed(verifyPrivateVarList(*this)))
3653 return failure();
3654
3655 if (failed(verifyReductionVarList(*this, getReductionSyms(),
3656 getReductionVars(), getReductionByref())) ||
3657 failed(verifyReductionVarList(*this, getInReductionSyms(),
3658 getInReductionVars(),
3659 getInReductionByref())))
3660 return failure();
3661
3662 if (!getReductionVars().empty() && getNogroup())
3663 return emitError("if a reduction clause is present on the taskloop "
3664 "directive, the nogroup clause must not be specified");
3665 for (auto var : getReductionVars()) {
3666 if (llvm::is_contained(getInReductionVars(), var))
3667 return emitError("the same list item cannot appear in both a reduction "
3668 "and an in_reduction clause");
3669 }
3670
3671 if (getGrainsize() && getNumTasks()) {
3672 return emitError(
3673 "the grainsize clause and num_tasks clause are mutually exclusive and "
3674 "may not appear on the same taskloop directive");
3675 }
3676
3677 return success();
3678}
3679
3680LogicalResult TaskloopContextOp::verifyRegions() {
3681 Region &region = getRegion();
3682 if (region.empty())
3683 return emitOpError() << "expected non-empty region";
3684
3685 auto count = llvm::count_if(region.front(), [](mlir::Operation &op) {
3686 return isa<TaskloopWrapperOp>(op);
3687 });
3688 if (count != 1)
3689 return emitOpError()
3690 << "expected exactly 1 TaskloopWrapperOp directly nested in "
3691 "the region, but "
3692 << count << " were found";
3693 TaskloopWrapperOp loopWrapperOp = getLoopOp();
3694
3695 auto loopNestOp = dyn_cast<LoopNestOp>(loopWrapperOp.getWrappedLoop());
3696 // This will fail the verifier for TaskloopWrapperOp and print an error
3697 // message there.
3698 if (!loopNestOp)
3699 return failure();
3700
3701 std::function<bool(Value)> isValidBoundValue = [&](Value value) -> bool {
3702 Region *valueRegion = value.getParentRegion();
3703 // A loop bound value defined outside of the taskloop context region is
3704 // valid. A region is considered an ancestor of itself.
3705 if (!region.isAncestor(valueRegion))
3706 return true;
3707
3708 Operation *defOp = value.getDefiningOp();
3709 if (!defOp || defOp->getNumRegions() != 0 || !isPure(defOp))
3710 return false;
3711
3712 return llvm::all_of(defOp->getOperands(), isValidBoundValue);
3713 };
3714 auto hasUnsupportedTaskloopLocalBound = [&](OperandRange range) -> bool {
3715 return llvm::any_of(range,
3716 [&](Value value) { return !isValidBoundValue(value); });
3717 };
3718
3719 if (hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopLowerBounds()) ||
3720 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopUpperBounds()) ||
3721 hasUnsupportedTaskloopLocalBound(loopNestOp.getLoopSteps())) {
3722 return emitOpError()
3723 << "expects loop bounds and steps to be defined outside of the "
3724 "taskloop.context region or by pure, regionless operations "
3725 "that do not depend on block arguments";
3726 }
3727
3728 return success();
3729}
3730
3731//===----------------------------------------------------------------------===//
3732// TaskloopWrapperOp
3733//===----------------------------------------------------------------------===//
3734
3735void TaskloopWrapperOp::build(OpBuilder &builder, OperationState &state,
3736 const TaskloopWrapperOperands &clauses) {
3737 TaskloopWrapperOp::build(builder, state);
3738}
3739
3740TaskloopContextOp TaskloopWrapperOp::getTaskloopContext() {
3741 return dyn_cast<TaskloopContextOp>(getOperation()->getParentOp());
3742}
3743
3744LogicalResult TaskloopWrapperOp::verify() {
3745 TaskloopContextOp context = getTaskloopContext();
3746 if (!context)
3747 return emitOpError() << "expected to be nested in a taskloop context op";
3748 return success();
3749}
3750
3751LogicalResult TaskloopWrapperOp::verifyRegions() {
3752 if (LoopWrapperInterface nested = getNestedWrapper()) {
3753 if (!isComposite())
3754 return emitError()
3755 << "'omp.composite' attribute missing from composite wrapper";
3756
3757 // Check for the allowed leaf constructs that may appear in a composite
3758 // construct directly after TASKLOOP.
3759 if (!isa<SimdOp>(nested))
3760 return emitError() << "only supported nested wrapper is 'omp.simd'";
3761 } else if (isComposite()) {
3762 return emitError()
3763 << "'omp.composite' attribute present in non-composite wrapper";
3764 }
3765
3766 return success();
3767}
3768
3769//===----------------------------------------------------------------------===//
3770// LoopNestOp
3771//===----------------------------------------------------------------------===//
3772
/// Parses the custom assembly format of `omp.loop_nest`. In the order handled
/// below: a parenthesized induction-variable list, `:` and the single type
/// shared by all IVs/bounds/steps, `=` and parenthesized lower bounds, `to`
/// and parenthesized upper bounds, an optional `inclusive` keyword, `step` and
/// parenthesized steps, an optional `collapse(N)`, an optional
/// `tiles(t0, t1, ...)`, the body region, and an optional attribute dictionary.
///
/// NOTE(review): in this copy of the file the opening `if (` of the first
/// condition and the declarations of `ivs`, `lbs`, `ubs`, `steps` and `tiles`
/// are not visible; confirm they are present in the real source.
ParseResult LoopNestOp::parse(OpAsmParser &parser, OperationState &result) {
  // Parse an opening `(` followed by induction variables followed by `)`
  Type loopVarType;
      parser.parseColonType(loopVarType) ||
      // Parse loop bounds.
      parser.parseEqual() ||
      parser.parseOperandList(lbs, ivs.size(), OpAsmParser::Delimiter::Paren) ||
      parser.parseKeyword("to") ||
      parser.parseOperandList(ubs, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Every induction variable takes the single type parsed after the colon.
  for (auto &iv : ivs)
    iv.type = loopVarType;

  auto *ctx = parser.getBuilder().getContext();
  // Parse "inclusive" flag.
  if (succeeded(parser.parseOptionalKeyword("inclusive")))
    result.addAttribute("loop_inclusive", UnitAttr::get(ctx));

  // Parse step values.
  if (parser.parseKeyword("step") ||
      parser.parseOperandList(steps, ivs.size(), OpAsmParser::Delimiter::Paren))
    return failure();

  // Parse collapse
  // The parenthesized integer is only required when the `collapse` keyword is
  // present, and the attribute is only recorded for values greater than one.
  int64_t value = 0;
  if (!parser.parseOptionalKeyword("collapse") &&
      (parser.parseLParen() || parser.parseInteger(value) ||
       parser.parseRParen()))
    return failure();
  if (value > 1)
    result.addAttribute(
        "collapse_num_loops",
        IntegerAttr::get(parser.getBuilder().getI64Type(), value));

  // Parse tiles
  // Each comma-separated list element is one integer tile size.
  auto parseTiles = [&]() -> ParseResult {
    int64_t tile;
    if (parser.parseInteger(tile))
      return failure();
    tiles.push_back(tile);
    return success();
  };

  if (!parser.parseOptionalKeyword("tiles") &&
      (parser.parseLParen() || parser.parseCommaSeparatedList(parseTiles) ||
       parser.parseRParen()))
    return failure();

  // Only record tile sizes when at least one was given.
  if (tiles.size() > 0)
    result.addAttribute("tile_sizes", DenseI64ArrayAttr::get(ctx, tiles));

  // Parse the body.
  // The induction variables become the entry block arguments of the region.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, ivs))
    return failure();

  // Resolve operands.
  // Operands are appended in the order lower bounds, upper bounds, steps, all
  // resolved with the loop variable type.
  if (parser.resolveOperands(lbs, loopVarType, result.operands) ||
      parser.resolveOperands(ubs, loopVarType, result.operands) ||
      parser.resolveOperands(steps, loopVarType, result.operands))
    return failure();

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}
3844
3845void LoopNestOp::print(OpAsmPrinter &p) {
3846 Region &region = getRegion();
3847 auto args = region.getArguments();
3848 p << " (" << args << ") : " << args[0].getType() << " = ("
3849 << getLoopLowerBounds() << ") to (" << getLoopUpperBounds() << ") ";
3850 if (getLoopInclusive())
3851 p << "inclusive ";
3852 p << "step (" << getLoopSteps() << ") ";
3853 if (int64_t numCollapse = getCollapseNumLoops())
3854 if (numCollapse > 1)
3855 p << "collapse(" << numCollapse << ") ";
3856
3857 if (const auto tiles = getTileSizes())
3858 p << "tiles(" << tiles.value() << ") ";
3859
3860 p.printRegion(region, /*printEntryBlockArgs=*/false);
3861}
3862
3863void LoopNestOp::build(OpBuilder &builder, OperationState &state,
3864 const LoopNestOperands &clauses) {
3865 MLIRContext *ctx = builder.getContext();
3866 LoopNestOp::build(builder, state, clauses.collapseNumLoops,
3867 clauses.loopLowerBounds, clauses.loopUpperBounds,
3868 clauses.loopSteps, clauses.loopInclusive,
3869 makeDenseI64ArrayAttr(ctx, clauses.tileSizes));
3870}
3871
3872LogicalResult LoopNestOp::verify() {
3873 if (getLoopLowerBounds().empty())
3874 return emitOpError() << "must represent at least one loop";
3875
3876 if (getLoopLowerBounds().size() != getIVs().size())
3877 return emitOpError() << "number of range arguments and IVs do not match";
3878
3879 for (auto [lb, iv] : llvm::zip_equal(getLoopLowerBounds(), getIVs())) {
3880 if (lb.getType() != iv.getType())
3881 return emitOpError()
3882 << "range argument type does not match corresponding IV type";
3883 }
3884
3885 uint64_t numIVs = getIVs().size();
3886
3887 if (const auto &numCollapse = getCollapseNumLoops())
3888 if (numCollapse > numIVs)
3889 return emitOpError()
3890 << "collapse value is larger than the number of loops";
3891
3892 if (const auto &tiles = getTileSizes())
3893 if (tiles.value().size() > numIVs)
3894 return emitOpError() << "too few canonical loops for tile dimensions";
3895
3896 if (!llvm::dyn_cast_if_present<LoopWrapperInterface>((*this)->getParentOp()))
3897 return emitOpError() << "expects parent op to be a loop wrapper";
3898
3899 return success();
3900}
3901
/// Collects the loop wrappers enclosing this loop nest into `wrappers`,
/// innermost first. Starting from the immediate parent op, each ancestor that
/// implements LoopWrapperInterface is appended. The walk stops at the first
/// ancestor that does not implement it, or when no parent remains.
///
/// NOTE(review): the parameter-list line of this definition is not visible in
/// this copy of the file; the body appends to a container named `wrappers`.
void LoopNestOp::gatherWrappers(
  Operation *parent = (*this)->getParentOp();
  while (auto wrapper =
             llvm::dyn_cast_if_present<LoopWrapperInterface>(parent)) {
    wrappers.push_back(wrapper);
    parent = parent->getParentOp();
  }
}
3911
3912//===----------------------------------------------------------------------===//
3913// OpenMP canonical loop handling
3914//===----------------------------------------------------------------------===//
3915
3916std::tuple<NewCliOp, OpOperand *, OpOperand *>
3917mlir::omp ::decodeCli(Value cli) {
3918
3919 // Defining a CLI for a generated loop is optional; if there is none then
3920 // there is no followup-tranformation
3921 if (!cli)
3922 return {{}, nullptr, nullptr};
3923
3924 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
3925 "Unexpected type of cli");
3926
3927 NewCliOp create = cast<NewCliOp>(cli.getDefiningOp());
3928 OpOperand *gen = nullptr;
3929 OpOperand *cons = nullptr;
3930 for (OpOperand &use : cli.getUses()) {
3931 auto op = cast<LoopTransformationInterface>(use.getOwner());
3932
3933 unsigned opnum = use.getOperandNumber();
3934 if (op.isGeneratee(opnum)) {
3935 assert(!gen && "Each CLI may have at most one def");
3936 gen = &use;
3937 } else if (op.isApplyee(opnum)) {
3938 assert(!cons && "Each CLI may have at most one consumer");
3939 cons = &use;
3940 } else {
3941 llvm_unreachable("Unexpected operand for a CLI");
3942 }
3943 }
3944
3945 return {create, gen, cons};
3946}
3947
3948void NewCliOp::build(::mlir::OpBuilder &odsBuilder,
3949 ::mlir::OperationState &odsState) {
3950 odsState.addTypes(CanonicalLoopInfoType::get(odsBuilder.getContext()));
3951}
3952
/// Suggests an SSA name for the CLI result based on the op that generates it.
/// Without a generator the name is "cli".
///
/// NOTE(review): the line that starts the type switch over the generator's
/// owner op (between `cliName =` and the first `.Case`) is not visible in this
/// copy of the file.
void NewCliOp::getAsmResultNames(OpAsmSetValueNameFn setNameFn) {
  Value result = getResult();
  auto [newCli, gen, cons] = decodeCli(result);

  // Structured binding `gen` cannot be captured in lambdas before C++20
  OpOperand *generator = gen;

  // Derive the CLI variable name from its generator:
  // * "canonloop" for omp.canonical_loop
  // * custom name for loop transformation generatees
  // * "cli" as fallback if no generator
  // * "_r<idx>" suffix for nested loops, where <idx> is the sequential order
  // at that level
  // * "_s<idx>" suffix for operations with multiple regions, where <idx> is
  // the index of that region
  std::string cliName{"cli"};
  if (gen) {
    cliName =
            .Case([&](CanonicalLoopOp op) {
              return generateLoopNestingName("canonloop", op);
            })
            .Case([&](UnrollHeuristicOp op) -> std::string {
              llvm_unreachable("heuristic unrolling does not generate a loop");
            })
            .Case([&](FuseOp op) -> std::string {
              unsigned opnum = generator->getOperandNumber();
              // The position of the first loop to be fused is the same position
              // as the resulting fused loop
              // NOTE(review): `opnum` is an operand index, while `first` is a
              // 1-based loop position (see the 1-based note in the TileOp case
              // below, and FuseOp::verify). Unlike the TileOp case, no
              // generatee start offset is subtracted here. Confirm this
              // comparison is not off by one.
              if (op.getFirst().has_value() && opnum != op.getFirst().value())
                return "canonloop_fuse";
              else
                return "fused";
            })
            .Case([&](TileOp op) -> std::string {
              auto [generateesFirst, generateesCount] =
                  op.getGenerateesODSOperandIndexAndLength();
              // The first half of the generatees are the grid loops, the
              // second half the intra-tile loops.
              unsigned firstGrid = generateesFirst;
              unsigned firstIntratile = generateesFirst + generateesCount / 2;
              unsigned end = generateesFirst + generateesCount;
              unsigned opnum = generator->getOperandNumber();
              // In the OpenMP apply and looprange clauses, indices are 1-based
              if (firstGrid <= opnum && opnum < firstIntratile) {
                unsigned gridnum = opnum - firstGrid + 1;
                return ("grid" + Twine(gridnum)).str();
              }
              if (firstIntratile <= opnum && opnum < end) {
                unsigned intratilenum = opnum - firstIntratile + 1;
                return ("intratile" + Twine(intratilenum)).str();
              }
              llvm_unreachable("Unexpected generatee argument");
            })
            .DefaultUnreachable("TODO: Custom name for this operation");
  }

  setNameFn(result, cliName);
}
4010
4011LogicalResult NewCliOp::verify() {
4012 Value cli = getResult();
4013
4014 assert(cli.getType() == CanonicalLoopInfoType::get(cli.getContext()) &&
4015 "Unexpected type of cli");
4016
4017 // Check that the CLI is used in at most generator and one consumer
4018 OpOperand *gen = nullptr;
4019 OpOperand *cons = nullptr;
4020 for (mlir::OpOperand &use : cli.getUses()) {
4021 auto op = cast<mlir::omp::LoopTransformationInterface>(use.getOwner());
4022
4023 unsigned opnum = use.getOperandNumber();
4024 if (op.isGeneratee(opnum)) {
4025 if (gen) {
4026 InFlightDiagnostic error =
4027 emitOpError("CLI must have at most one generator");
4028 error.attachNote(gen->getOwner()->getLoc())
4029 .append("first generator here:");
4030 error.attachNote(use.getOwner()->getLoc())
4031 .append("second generator here:");
4032 return error;
4033 }
4034
4035 gen = &use;
4036 } else if (op.isApplyee(opnum)) {
4037 if (cons) {
4038 InFlightDiagnostic error =
4039 emitOpError("CLI must have at most one consumer");
4040 error.attachNote(cons->getOwner()->getLoc())
4041 .append("first consumer here:")
4042 .appendOp(*cons->getOwner(),
4043 OpPrintingFlags().printGenericOpForm());
4044 error.attachNote(use.getOwner()->getLoc())
4045 .append("second consumer here:")
4046 .appendOp(*use.getOwner(), OpPrintingFlags().printGenericOpForm());
4047 return error;
4048 }
4049
4050 cons = &use;
4051 } else {
4052 llvm_unreachable("Unexpected operand for a CLI");
4053 }
4054 }
4055
4056 // If the CLI is source of a transformation, it must have a generator
4057 if (cons && !gen) {
4058 InFlightDiagnostic error = emitOpError("CLI has no generator");
4059 error.attachNote(cons->getOwner()->getLoc())
4060 .append("see consumer here: ")
4061 .appendOp(*cons->getOwner(), OpPrintingFlags().printGenericOpForm());
4062 return error;
4063 }
4064
4065 return success();
4066}
4067
/// Builds an `omp.canonical_loop` over `tripCount` with a null CLI operand and
/// an empty body region.
///
/// NOTE(review): `addOperands(Value())` appends a null Value as the second
/// (cli) operand instead of omitting the optional operand. Confirm that the
/// ODS-generated accessors and verifier tolerate a null optional operand, or
/// whether this overload should add only `tripCount`.
void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
                            Value tripCount) {
  odsState.addOperands(tripCount);
  odsState.addOperands(Value());
  (void)odsState.addRegion();
}
4074
4075void CanonicalLoopOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4076 Value tripCount, ::mlir::Value cli) {
4077 odsState.addOperands(tripCount);
4078 odsState.addOperands(cli);
4079 (void)odsState.addRegion();
4080}
4081
4082void CanonicalLoopOp::getAsmBlockNames(OpAsmSetBlockNameFn setNameFn) {
4083 setNameFn(&getRegion().front(), "body_entry");
4084}
4085
4086void CanonicalLoopOp::getAsmBlockArgumentNames(Region &region,
4087 OpAsmSetValueNameFn setNameFn) {
4088 std::string ivName = generateLoopNestingName("iv", *this);
4089 setNameFn(region.getArgument(0), ivName);
4090}
4091
4092void CanonicalLoopOp::print(OpAsmPrinter &p) {
4093 if (getCli())
4094 p << '(' << getCli() << ')';
4095 p << ' ' << getInductionVar() << " : " << getInductionVar().getType()
4096 << " in range(" << getTripCount() << ") ";
4097
4098 p.printRegion(getRegion(), /*printEntryBlockArgs=*/false,
4099 /*printBlockTerminators=*/true);
4100
4101 p.printOptionalAttrDict((*this)->getAttrs());
4102}
4103
/// Parses `[(%cli)] %iv : type in range(%tripcount) { body } [attr-dict]`.
/// The trip count is resolved with the induction variable's type. The optional
/// CLI operand is appended after the trip count, even though it is parsed
/// first.
///
/// NOTE(review): the second parameter line and the declarations of the `cli`
/// and `tripcount` unresolved operands are not visible in this copy of the
/// file.
mlir::ParseResult CanonicalLoopOp::parse(::mlir::OpAsmParser &parser,
  CanonicalLoopInfoType cliType =
      CanonicalLoopInfoType::get(parser.getContext());

  // Parse (optional) omp.cli identifier
  SmallVector<mlir::Value, 1> cliOperand;
  if (!parser.parseOptionalLParen()) {
    if (parser.parseOperand(cli) ||
        parser.resolveOperand(cli, cliType, cliOperand) || parser.parseRParen())
      return failure();
  }

  // We derive the type of tripCount from inductionVariable. MLIR requires the
  // type of tripCount to be known when calling resolveOperand, so we have to
  // parse the type before processing the inductionVariable.
  OpAsmParser::Argument inductionVariable;
  if (parser.parseArgument(inductionVariable, /*allowType*/ true) ||
      parser.parseKeyword("in") || parser.parseKeyword("range") ||
      parser.parseLParen() || parser.parseOperand(tripcount) ||
      parser.parseRParen() ||
      parser.resolveOperand(tripcount, inductionVariable.type, result.operands))
    return failure();

  // Parse the loop body.
  // The induction variable becomes the entry block argument.
  Region *region = result.addRegion();
  if (parser.parseRegion(*region, {inductionVariable}))
    return failure();

  // We parsed the cli operand first, but because it is optional, it must be
  // last in the operand list.
  result.operands.append(cliOperand);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4145
4146LogicalResult CanonicalLoopOp::verify() {
4147 // The region's entry must accept the induction variable
4148 // It can also be empty if just created
4149 if (!getRegion().empty()) {
4150 Region &region = getRegion();
4151 if (region.getNumArguments() != 1)
4152 return emitOpError(
4153 "Canonical loop region must have exactly one argument");
4154
4155 if (getInductionVar().getType() != getTripCount().getType())
4156 return emitOpError(
4157 "Region argument must be the same type as the trip count");
4158 }
4159
4160 return success();
4161}
4162
4163Value CanonicalLoopOp::getInductionVar() { return getRegion().getArgument(0); }
4164
4165std::pair<unsigned, unsigned>
4166CanonicalLoopOp::getApplyeesODSOperandIndexAndLength() {
4167 // No applyees
4168 return {0, 0};
4169}
4170
4171std::pair<unsigned, unsigned>
4172CanonicalLoopOp::getGenerateesODSOperandIndexAndLength() {
4173 return getODSOperandIndexAndLength(odsIndex_cli);
4174}
4175
4176//===----------------------------------------------------------------------===//
4177// UnrollHeuristicOp
4178//===----------------------------------------------------------------------===//
4179
4180void UnrollHeuristicOp::build(::mlir::OpBuilder &odsBuilder,
4181 ::mlir::OperationState &odsState,
4182 ::mlir::Value cli) {
4183 odsState.addOperands(cli);
4184}
4185
4186void UnrollHeuristicOp::print(OpAsmPrinter &p) {
4187 p << '(' << getApplyee() << ')';
4188
4189 p.printOptionalAttrDict((*this)->getAttrs());
4190}
4191
/// Parses `(%applyee) [-> ()] [attr-dict]`. The applyee is resolved with the
/// CLI type. An optional `-> ()` is accepted and discarded; no result is
/// created.
///
/// NOTE(review): the second parameter line and the declaration of the
/// `applyee` unresolved operand are not visible in this copy of the file.
mlir::ParseResult UnrollHeuristicOp::parse(::mlir::OpAsmParser &parser,
  auto cliType = CanonicalLoopInfoType::get(parser.getContext());

  if (parser.parseLParen())
    return failure();

  if (parser.parseOperand(applyee) ||
      parser.resolveOperand(applyee, cliType, result.operands))
    return failure();

  if (parser.parseRParen())
    return failure();

  // Optional output loop (full unrolling has none)
  if (!parser.parseOptionalArrow()) {
    if (parser.parseLParen() || parser.parseRParen())
      return failure();
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return mlir::success();
}
4219
4220std::pair<unsigned, unsigned>
4221UnrollHeuristicOp ::getApplyeesODSOperandIndexAndLength() {
4222 return getODSOperandIndexAndLength(odsIndex_applyee);
4223}
4224
4225std::pair<unsigned, unsigned>
4226UnrollHeuristicOp::getGenerateesODSOperandIndexAndLength() {
4227 return {0, 0};
4228}
4229
4230//===----------------------------------------------------------------------===//
4231// TileOp
4232//===----------------------------------------------------------------------===//
4233
4234static void printLoopTransformClis(OpAsmPrinter &p, TileOp op,
4235 OperandRange generatees,
4236 OperandRange applyees) {
4237 if (!generatees.empty())
4238 p << '(' << llvm::interleaved(generatees) << ')';
4239
4240 if (!applyees.empty())
4241 p << " <- (" << llvm::interleaved(applyees) << ')';
4242}
4243
/// Parses the CLI operands of a loop transformation in one of two forms:
///   `(gen, ...) <- (app, ...)`   -- generatees present
///   `<- (app, ...)`               -- generatees omitted
/// `parseOptionalLess` fails when the next token is not `<`; that case is
/// taken to mean that a generatee list comes first. The applyee list is always
/// required.
///
/// NOTE(review): the remaining parameter lines and the delimiter arguments of
/// both `parseOperandList` calls are not visible in this copy of the file;
/// presumably they are parenthesized lists, mirroring printLoopTransformClis.
static ParseResult parseLoopTransformClis(
    OpAsmParser &parser,
  if (parser.parseOptionalLess()) {
    // Syntax 1: generatees present

    if (parser.parseOperandList(generateesOperands,
      return failure();

    if (parser.parseLess())
      return failure();
  } else {
    // Syntax 2: generatees omitted
  }

  // Parse `<-` (`<` has already been parsed)
  if (parser.parseMinus())
    return failure();

  if (parser.parseOperandList(applyeesOperands,
    return failure();

  return success();
}
4271
/// Check properties of the loop nest consisting of the transformation's
/// applyees:
/// 1. They are nested inside each other
/// 2. They are perfectly nested
///    (no code with side-effects in-between the loops)
/// 3. They are rectangular
///    (loop bounds are invariant in respect to the outer loops)
///
/// Every applyee CLI must have a generator, otherwise an error is emitted.
/// Properties 1-3 are only checked when every applyee is generated by an
/// `omp.canonical_loop`; for any other generator the function succeeds
/// without further checks. As implemented, "perfectly nested" is stricter than
/// "no side effects": the parent body must be a single block holding exactly
/// the nested loop followed by `omp.terminator`. "Rectangular" only rejects a
/// trip count that is directly an outer loop's induction variable; values
/// derived from an outer IV are not detected.
///
/// NOTE(review): the declaration of `canonLoops` is not visible in this copy
/// of the file.
///
/// TODO: Generalize for LoopTransformationInterface.
static LogicalResult checkApplyeesNesting(TileOp op) {
  // Collect the loops from the nest
  bool isOnlyCanonLoops = true;
  for (Value applyee : op.getApplyees()) {
    auto [create, gen, cons] = decodeCli(applyee);

    if (!gen)
      return op.emitOpError() << "applyee CLI has no generator";

    auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
    canonLoops.push_back(loop);
    if (!loop)
      isOnlyCanonLoops = false;
  }

  // FIXME: We currently can only verify rectangularity and perfect nesting of
  // omp.canonical_loop.
  if (!isOnlyCanonLoops)
    return success();

  // Induction variables of all loops enclosing the current one in the nest.
  DenseSet<Value> parentIVs;
  for (auto i : llvm::seq<int>(1, canonLoops.size())) {
    auto parentLoop = canonLoops[i - 1];
    auto loop = canonLoops[i];

    if (parentLoop.getOperation() != loop.getOperation()->getParentOp())
      return op.emitOpError()
             << "tiled loop nest must be nested within each other";

    parentIVs.insert(parentLoop.getInductionVar());

    // Canonical loop must be perfectly nested, i.e. the body of the parent must
    // only contain the omp.canonical_loop of the nested loops, and
    // omp.terminator
    bool isPerfectlyNested = [&]() {
      auto &parentBody = parentLoop.getRegion();
      if (!parentBody.hasOneBlock())
        return false;
      auto &parentBlock = parentBody.getBlocks().front();

      auto nestedLoopIt = parentBlock.begin();
      if (nestedLoopIt == parentBlock.end() ||
          (&*nestedLoopIt != loop.getOperation()))
        return false;

      auto termIt = std::next(nestedLoopIt);
      if (termIt == parentBlock.end() || !isa<TerminatorOp>(termIt))
        return false;

      if (std::next(termIt) != parentBlock.end())
        return false;

      return true;
    }();
    if (!isPerfectlyNested)
      return op.emitOpError() << "tiled loop nest must be perfectly nested";

    if (parentIVs.contains(loop.getTripCount()))
      return op.emitOpError() << "tiled loop nest must be rectangular";
  }

  // TODO: The tile sizes must be computed before the loop, but checking this
  // requires dominance analysis. For instance:
  //
  //      %canonloop = omp.new_cli
  //      omp.canonical_loop(%canonloop) %iv : i32 in range(%tc) {
  //        // write to %x
  //        omp.terminator
  //      }
  //      %ts = llvm.load %x
  //      omp.tile <- (%canonloop) sizes(%ts : i32)

  return success();
}
4356
4357LogicalResult TileOp::verify() {
4358 if (getApplyees().empty())
4359 return emitOpError() << "must apply to at least one loop";
4360
4361 if (getSizes().size() != getApplyees().size())
4362 return emitOpError() << "there must be one tile size for each applyee";
4363
4364 if (!getGeneratees().empty() &&
4365 2 * getSizes().size() != getGeneratees().size())
4366 return emitOpError()
4367 << "expecting two times the number of generatees than applyees";
4368
4369 return checkApplyeesNesting(*this);
4370}
4371
4372std::pair<unsigned, unsigned> TileOp ::getApplyeesODSOperandIndexAndLength() {
4373 return getODSOperandIndexAndLength(odsIndex_applyees);
4374}
4375
4376std::pair<unsigned, unsigned> TileOp::getGenerateesODSOperandIndexAndLength() {
4377 return getODSOperandIndexAndLength(odsIndex_generatees);
4378}
4379
4380//===----------------------------------------------------------------------===//
4381// FuseOp
4382//===----------------------------------------------------------------------===//
4383
4384static void printLoopTransformClis(OpAsmPrinter &p, FuseOp op,
4385 OperandRange generatees,
4386 OperandRange applyees) {
4387 if (!generatees.empty())
4388 p << '(' << llvm::interleaved(generatees) << ')';
4389
4390 if (!applyees.empty())
4391 p << " <- (" << llvm::interleaved(applyees) << ')';
4392}
4393
4394LogicalResult FuseOp::verify() {
4395 if (getApplyees().size() < 2)
4396 return emitOpError() << "must apply to at least two loops";
4397
4398 if (getFirst().has_value() && getCount().has_value()) {
4399 int64_t first = getFirst().value();
4400 int64_t count = getCount().value();
4401 if ((unsigned)(first + count - 1) > getApplyees().size())
4402 return emitOpError() << "the numbers of applyees must be at least first "
4403 "minus one plus count attributes";
4404 if (!getGeneratees().empty() &&
4405 getGeneratees().size() != getApplyees().size() + 1 - count)
4406 return emitOpError() << "the number of generatees must be the number of "
4407 "aplyees plus one minus count";
4408
4409 } else {
4410 if (!getGeneratees().empty() && getGeneratees().size() != 1)
4411 return emitOpError()
4412 << "in a complete fuse the number of generatees must be exactly 1";
4413 }
4414 for (auto &&applyee : getApplyees()) {
4415 auto [create, gen, cons] = decodeCli(applyee);
4416
4417 if (!gen)
4418 return emitOpError() << "applyee CLI has no generator";
4419 auto loop = dyn_cast_or_null<CanonicalLoopOp>(gen->getOwner());
4420 if (!loop)
4421 return emitOpError()
4422 << "currently only supports omp.canonical_loop as applyee";
4423 }
4424 return success();
4425}
4426std::pair<unsigned, unsigned> FuseOp::getApplyeesODSOperandIndexAndLength() {
4427 return getODSOperandIndexAndLength(odsIndex_applyees);
4428}
4429
4430std::pair<unsigned, unsigned> FuseOp::getGenerateesODSOperandIndexAndLength() {
4431 return getODSOperandIndexAndLength(odsIndex_generatees);
4432}
4433
4434//===----------------------------------------------------------------------===//
4435// Critical construct (2.17.1)
4436//===----------------------------------------------------------------------===//
4437
4438void CriticalDeclareOp::build(OpBuilder &builder, OperationState &state,
4439 const CriticalDeclareOperands &clauses) {
4440 CriticalDeclareOp::build(builder, state, clauses.symName, clauses.hint);
4441}
4442
4443LogicalResult CriticalDeclareOp::verify() {
4444 return verifySynchronizationHint(*this, getHint());
4445}
4446
4447LogicalResult CriticalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
4448 if (getNameAttr()) {
4449 SymbolRefAttr symbolRef = getNameAttr();
4450 auto decl = symbolTable.lookupNearestSymbolFrom<CriticalDeclareOp>(
4451 *this, symbolRef);
4452 if (!decl) {
4453 return emitOpError() << "expected symbol reference " << symbolRef
4454 << " to point to a critical declaration";
4455 }
4456 }
4457
4458 return success();
4459}
4460
4461//===----------------------------------------------------------------------===//
4462// Ordered construct
4463//===----------------------------------------------------------------------===//
4464
4465static LogicalResult verifyOrderedParent(Operation &op) {
4466 bool hasRegion = op.getNumRegions() > 0;
4467 auto loopOp = op.getParentOfType<LoopNestOp>();
4468 if (!loopOp) {
4469 if (hasRegion)
4470 return success();
4471
4472 // TODO: Consider if this needs to be the case only for the standalone
4473 // variant of the ordered construct.
4474 return op.emitOpError() << "must be nested inside of a loop";
4475 }
4476
4477 Operation *wrapper = loopOp->getParentOp();
4478 if (auto wsloopOp = dyn_cast<WsloopOp>(wrapper)) {
4479 IntegerAttr orderedAttr = wsloopOp.getOrderedAttr();
4480 if (!orderedAttr)
4481 return op.emitOpError() << "the enclosing worksharing-loop region must "
4482 "have an ordered clause";
4483
4484 if (hasRegion && orderedAttr.getInt() != 0)
4485 return op.emitOpError() << "the enclosing loop's ordered clause must not "
4486 "have a parameter present";
4487
4488 if (!hasRegion && orderedAttr.getInt() == 0)
4489 return op.emitOpError() << "the enclosing loop's ordered clause must "
4490 "have a parameter present";
4491 } else if (!isa<SimdOp>(wrapper)) {
4492 return op.emitOpError() << "must be nested inside of a worksharing, simd "
4493 "or worksharing simd loop";
4494 }
4495 return success();
4496}
4497
4498void OrderedOp::build(OpBuilder &builder, OperationState &state,
4499 const OrderedOperands &clauses) {
4500 OrderedOp::build(builder, state, clauses.doacrossDependType,
4501 clauses.doacrossNumLoops, clauses.doacrossDependVars);
4502}
4503
4504LogicalResult OrderedOp::verify() {
4505 if (failed(verifyOrderedParent(**this)))
4506 return failure();
4507
4508 auto wrapper = (*this)->getParentOfType<WsloopOp>();
4509 if (!wrapper || *wrapper.getOrdered() != *getDoacrossNumLoops())
4510 return emitOpError() << "number of variables in depend clause does not "
4511 << "match number of iteration variables in the "
4512 << "doacross loop";
4513
4514 return success();
4515}
4516
4517void OrderedRegionOp::build(OpBuilder &builder, OperationState &state,
4518 const OrderedRegionOperands &clauses) {
4519 OrderedRegionOp::build(builder, state, clauses.parLevelSimd);
4520}
4521
4522LogicalResult OrderedRegionOp::verify() { return verifyOrderedParent(**this); }
4523
4524//===----------------------------------------------------------------------===//
4525// TaskwaitOp
4526//===----------------------------------------------------------------------===//
4527
4528void TaskwaitOp::build(OpBuilder &builder, OperationState &state,
4529 const TaskwaitOperands &clauses) {
4530 // TODO Store clauses in op: dependKinds, dependVars, nowait.
4531 TaskwaitOp::build(builder, state, /*depend_kinds=*/nullptr,
4532 /*depend_vars=*/{}, /*depend_iterated_kinds=*/nullptr,
4533 /*depend_iterated=*/{}, /*nowait=*/nullptr);
4534}
4535
4536//===----------------------------------------------------------------------===//
4537// Verifier for AtomicReadOp
4538//===----------------------------------------------------------------------===//
4539
4540LogicalResult AtomicReadOp::verify() {
4541 if (verifyCommon().failed())
4542 return mlir::failure();
4543
4544 int64_t version = 50;
4545 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4546 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4547 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4548
4549 if (auto mo = getMemoryOrder()) {
4550 if (*mo == ClauseMemoryOrderKind::Release) {
4551 return emitError("memory-order must not be release for atomic reads");
4552 }
4553 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4554 // acq_rel is prohibited on read only in OpenMP 5.0; allowed in 5.1+.
4555 if (version < 51)
4556 return emitError("memory-order must not be acq_rel for atomic reads");
4557 }
4558 }
4559 return verifySynchronizationHint(*this, getHint());
4560}
4561
4562//===----------------------------------------------------------------------===//
4563// Verifier for AtomicWriteOp
4564//===----------------------------------------------------------------------===//
4565
4566LogicalResult AtomicWriteOp::verify() {
4567 if (verifyCommon().failed())
4568 return mlir::failure();
4569
4570 int64_t version = 50;
4571 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4572 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4573 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4574
4575 if (auto mo = getMemoryOrder()) {
4576 if (*mo == ClauseMemoryOrderKind::Acquire) {
4577 return emitError("memory-order must not be acquire for atomic writes");
4578 }
4579 if (*mo == ClauseMemoryOrderKind::Acq_rel) {
4580 // acq_rel is prohibited on write only in OpenMP 5.0; allowed in 5.1+.
4581 if (version < 51)
4582 return emitError("memory-order must not be acq_rel for atomic writes");
4583 }
4584 }
4585 return verifySynchronizationHint(*this, getHint());
4586}
4587
4588//===----------------------------------------------------------------------===//
4589// Verifier for AtomicUpdateOp
4590//===----------------------------------------------------------------------===//
4591
4592LogicalResult AtomicUpdateOp::canonicalize(AtomicUpdateOp op,
4593 PatternRewriter &rewriter) {
4594 if (op.isNoOp()) {
4595 rewriter.eraseOp(op);
4596 return success();
4597 }
4598 if (Value writeVal = op.getWriteOpVal()) {
4599 rewriter.replaceOpWithNewOp<AtomicWriteOp>(
4600 op, op.getX(), writeVal, op.getHintAttr(), op.getMemoryOrderAttr());
4601 return success();
4602 }
4603 return failure();
4604}
4605
4606LogicalResult AtomicUpdateOp::verify() {
4607 if (verifyCommon().failed())
4608 return mlir::failure();
4609
4610 int64_t version = 50;
4611 if (auto moduleOp = getOperation()->getParentOfType<ModuleOp>())
4612 if (Attribute verAttr = moduleOp->getAttr("omp.version"))
4613 version = llvm::cast<VersionAttr>(verAttr).getVersion();
4614
4615 if (auto mo = getMemoryOrder()) {
4616 if (*mo == ClauseMemoryOrderKind::Acq_rel ||
4617 *mo == ClauseMemoryOrderKind::Acquire) {
4618 // This restriction applies only to OpenMP 5.0; removed in 5.1.
4619 if (version < 51)
4620 return emitError(
4621 "memory-order must not be acq_rel or acquire for atomic updates");
4622 }
4623 }
4624
4625 return verifySynchronizationHint(*this, getHint());
4626}
4627
4628LogicalResult AtomicUpdateOp::verifyRegions() { return verifyRegionsCommon(); }
4629
4630//===----------------------------------------------------------------------===//
4631// Verifier for AtomicCaptureOp
4632//===----------------------------------------------------------------------===//
4633
4634AtomicReadOp AtomicCaptureOp::getAtomicReadOp() {
4635 if (auto op = dyn_cast<AtomicReadOp>(getFirstOp()))
4636 return op;
4637 return dyn_cast<AtomicReadOp>(getSecondOp());
4638}
4639
4640AtomicWriteOp AtomicCaptureOp::getAtomicWriteOp() {
4641 if (auto op = dyn_cast<AtomicWriteOp>(getFirstOp()))
4642 return op;
4643 return dyn_cast<AtomicWriteOp>(getSecondOp());
4644}
4645
4646AtomicUpdateOp AtomicCaptureOp::getAtomicUpdateOp() {
4647 if (auto op = dyn_cast<AtomicUpdateOp>(getFirstOp()))
4648 return op;
4649 return dyn_cast<AtomicUpdateOp>(getSecondOp());
4650}
4651
4652LogicalResult AtomicCaptureOp::verify() {
4653 return verifySynchronizationHint(*this, getHint());
4654}
4655
4656LogicalResult AtomicCaptureOp::verifyRegions() {
4657 if (verifyRegionsCommon().failed())
4658 return mlir::failure();
4659
4660 if (getFirstOp()->getAttr("hint") || getSecondOp()->getAttr("hint"))
4661 return emitOpError(
4662 "operations inside capture region must not have hint clause");
4663
4664 if (getFirstOp()->getAttr("memory_order") ||
4665 getSecondOp()->getAttr("memory_order"))
4666 return emitOpError(
4667 "operations inside capture region must not have memory_order clause");
4668 return success();
4669}
4670
4671//===----------------------------------------------------------------------===//
4672// AtomicCompareOp
4673//===----------------------------------------------------------------------===//
4674
4675LogicalResult AtomicCompareOp::verify() {
4676 if (verifyCommon().failed())
4677 return mlir::failure();
4678 return verifySynchronizationHint(*this, getHint());
4679}
4680
4681LogicalResult AtomicCompareOp::verifyRegions() {
4682 if (verifyRegionsCommon().failed())
4683 return mlir::failure();
4684
4685 if (verifyOperator().failed())
4686 return mlir::failure();
4687
4688 Block &block = getRegion().front();
4689
4690 Operation *terminator = block.getTerminator();
4691 if (!terminator || !isa<YieldOp>(terminator))
4692 return emitOpError("region must be terminated with omp.yield");
4693
4694 return success();
4695}
4696
4697//===----------------------------------------------------------------------===//
4698// CancelOp
4699//===----------------------------------------------------------------------===//
4700
4701void CancelOp::build(OpBuilder &builder, OperationState &state,
4702 const CancelOperands &clauses) {
4703 CancelOp::build(builder, state, clauses.cancelDirective, clauses.ifExpr);
4704}
4705
4707 Operation *parent = thisOp->getParentOp();
4708 while (parent) {
4709 if (parent->getDialect() == thisOp->getDialect())
4710 return parent;
4711 parent = parent->getParentOp();
4712 }
4713 return nullptr;
4714}
4715
4716LogicalResult CancelOp::verify() {
4717 ClauseCancellationConstructType cct = getCancelDirective();
4718 // The next OpenMP operation in the chain of parents
4719 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4720 if (!structuralParent)
4721 return emitOpError() << "Orphaned cancel construct";
4722
4723 if ((cct == ClauseCancellationConstructType::Parallel) &&
4724 !mlir::isa<ParallelOp>(structuralParent)) {
4725 return emitOpError() << "cancel parallel must appear "
4726 << "inside a parallel region";
4727 }
4728 if (cct == ClauseCancellationConstructType::Loop) {
4729 // structural parent will be omp.loop_nest, directly nested inside
4730 // omp.wsloop
4731 auto wsloopOp = mlir::dyn_cast<WsloopOp>(structuralParent->getParentOp());
4732
4733 if (!wsloopOp) {
4734 return emitOpError()
4735 << "cancel loop must appear inside a worksharing-loop region";
4736 }
4737 if (wsloopOp.getNowaitAttr()) {
4738 return emitError() << "A worksharing construct that is canceled "
4739 << "must not have a nowait clause";
4740 }
4741 if (wsloopOp.getOrderedAttr()) {
4742 return emitError() << "A worksharing construct that is canceled "
4743 << "must not have an ordered clause";
4744 }
4745
4746 } else if (cct == ClauseCancellationConstructType::Sections) {
4747 // structural parent will be an omp.section, directly nested inside
4748 // omp.sections
4749 auto sectionsOp =
4750 mlir::dyn_cast<SectionsOp>(structuralParent->getParentOp());
4751 if (!sectionsOp) {
4752 return emitOpError() << "cancel sections must appear "
4753 << "inside a sections region";
4754 }
4755 if (sectionsOp.getNowait()) {
4756 return emitError() << "A sections construct that is canceled "
4757 << "must not have a nowait clause";
4758 }
4759 }
4760 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4761 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4762 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4763 return emitOpError() << "cancel taskgroup must appear "
4764 << "inside a task region";
4765 }
4766 return success();
4767}
4768
4769//===----------------------------------------------------------------------===//
4770// CancellationPointOp
4771//===----------------------------------------------------------------------===//
4772
4773void CancellationPointOp::build(OpBuilder &builder, OperationState &state,
4774 const CancellationPointOperands &clauses) {
4775 CancellationPointOp::build(builder, state, clauses.cancelDirective);
4776}
4777
4778LogicalResult CancellationPointOp::verify() {
4779 ClauseCancellationConstructType cct = getCancelDirective();
4780 // The next OpenMP operation in the chain of parents
4781 Operation *structuralParent = getParentInSameDialect((*this).getOperation());
4782 if (!structuralParent)
4783 return emitOpError() << "Orphaned cancellation point";
4784
4785 if ((cct == ClauseCancellationConstructType::Parallel) &&
4786 !mlir::isa<ParallelOp>(structuralParent)) {
4787 return emitOpError() << "cancellation point parallel must appear "
4788 << "inside a parallel region";
4789 }
4790 // Strucutal parent here will be an omp.loop_nest. Get the parent of that to
4791 // find the wsloop
4792 if ((cct == ClauseCancellationConstructType::Loop) &&
4793 !mlir::isa<WsloopOp>(structuralParent->getParentOp())) {
4794 return emitOpError() << "cancellation point loop must appear "
4795 << "inside a worksharing-loop region";
4796 }
4797 if ((cct == ClauseCancellationConstructType::Sections) &&
4798 !mlir::isa<omp::SectionOp>(structuralParent)) {
4799 return emitOpError() << "cancellation point sections must appear "
4800 << "inside a sections region";
4801 }
4802 if ((cct == ClauseCancellationConstructType::Taskgroup) &&
4803 (!mlir::isa<omp::TaskOp>(structuralParent) &&
4804 !mlir::isa<omp::TaskloopWrapperOp>(structuralParent->getParentOp()))) {
4805 return emitOpError() << "cancellation point taskgroup must appear "
4806 << "inside a task region";
4807 }
4808 return success();
4809}
4810
4811//===----------------------------------------------------------------------===//
4812// MapBoundsOp
4813//===----------------------------------------------------------------------===//
4814
4815LogicalResult MapBoundsOp::verify() {
4816 auto extent = getExtent();
4817 auto upperbound = getUpperBound();
4818 if (!extent && !upperbound)
4819 return emitError("expected extent or upperbound.");
4820 return success();
4821}
4822
4823void PrivateClauseOp::build(OpBuilder &odsBuilder, OperationState &odsState,
4824 TypeRange /*result_types*/, StringAttr symName,
4825 TypeAttr type) {
4826 PrivateClauseOp::build(
4827 odsBuilder, odsState, symName, type,
4828 DataSharingClauseTypeAttr::get(odsBuilder.getContext(),
4829 DataSharingClauseType::Private));
4830}
4831
4832LogicalResult PrivateClauseOp::verifyRegions() {
4833 Type argType = getArgType();
4834 auto verifyTerminator = [&](Operation *terminator,
4835 bool yieldsValue) -> LogicalResult {
4836 if (!terminator->getBlock()->getSuccessors().empty())
4837 return success();
4838
4839 if (!llvm::isa<YieldOp>(terminator))
4840 return mlir::emitError(terminator->getLoc())
4841 << "expected exit block terminator to be an `omp.yield` op.";
4842
4843 YieldOp yieldOp = llvm::cast<YieldOp>(terminator);
4844 TypeRange yieldedTypes = yieldOp.getResults().getTypes();
4845
4846 if (!yieldsValue) {
4847 if (yieldedTypes.empty())
4848 return success();
4849
4850 return mlir::emitError(terminator->getLoc())
4851 << "Did not expect any values to be yielded.";
4852 }
4853
4854 if (yieldedTypes.size() == 1 && yieldedTypes.front() == argType)
4855 return success();
4856
4857 auto error = mlir::emitError(yieldOp.getLoc())
4858 << "Invalid yielded value. Expected type: " << argType
4859 << ", got: ";
4860
4861 if (yieldedTypes.empty())
4862 error << "None";
4863 else
4864 error << yieldedTypes;
4865
4866 return error;
4867 };
4868
4869 auto verifyRegion = [&](Region &region, unsigned expectedNumArgs,
4870 StringRef regionName,
4871 bool yieldsValue) -> LogicalResult {
4872 assert(!region.empty());
4873
4874 if (region.getNumArguments() != expectedNumArgs)
4875 return mlir::emitError(region.getLoc())
4876 << "`" << regionName << "`: "
4877 << "expected " << expectedNumArgs
4878 << " region arguments, got: " << region.getNumArguments();
4879
4880 for (Block &block : region) {
4881 // MLIR will verify the absence of the terminator for us.
4882 if (!block.mightHaveTerminator())
4883 continue;
4884
4885 if (failed(verifyTerminator(block.getTerminator(), yieldsValue)))
4886 return failure();
4887 }
4888
4889 return success();
4890 };
4891
4892 // Ensure all of the region arguments have the same type
4893 for (Region *region : getRegions())
4894 for (Type ty : region->getArgumentTypes())
4895 if (ty != argType)
4896 return emitError() << "Region argument type mismatch: got " << ty
4897 << " expected " << argType << ".";
4898
4899 mlir::Region &initRegion = getInitRegion();
4900 if (!initRegion.empty() &&
4901 failed(verifyRegion(getInitRegion(), /*expectedNumArgs=*/2, "init",
4902 /*yieldsValue=*/true)))
4903 return failure();
4904
4905 DataSharingClauseType dsType = getDataSharingType();
4906
4907 if (dsType == DataSharingClauseType::Private && !getCopyRegion().empty())
4908 return emitError("`private` clauses do not require a `copy` region.");
4909
4910 if (dsType == DataSharingClauseType::FirstPrivate && getCopyRegion().empty())
4911 return emitError(
4912 "`firstprivate` clauses require at least a `copy` region.");
4913
4914 if (dsType == DataSharingClauseType::FirstPrivate &&
4915 failed(verifyRegion(getCopyRegion(), /*expectedNumArgs=*/2, "copy",
4916 /*yieldsValue=*/true)))
4917 return failure();
4918
4919 if (!getDeallocRegion().empty() &&
4920 failed(verifyRegion(getDeallocRegion(), /*expectedNumArgs=*/1, "dealloc",
4921 /*yieldsValue=*/false)))
4922 return failure();
4923
4924 return success();
4925}
4926
4927//===----------------------------------------------------------------------===//
4928// Spec 5.2: Masked construct (10.5)
4929//===----------------------------------------------------------------------===//
4930
4931void MaskedOp::build(OpBuilder &builder, OperationState &state,
4932 const MaskedOperands &clauses) {
4933 MaskedOp::build(builder, state, clauses.filteredThreadId);
4934}
4935
4936//===----------------------------------------------------------------------===//
4937// Spec 5.2: Scan construct (5.6)
4938//===----------------------------------------------------------------------===//
4939
4940void ScanOp::build(OpBuilder &builder, OperationState &state,
4941 const ScanOperands &clauses) {
4942 ScanOp::build(builder, state, clauses.inclusiveVars, clauses.exclusiveVars);
4943}
4944
4945LogicalResult ScanOp::verify() {
4946 if (hasExclusiveVars() == hasInclusiveVars())
4947 return emitError(
4948 "Exactly one of EXCLUSIVE or INCLUSIVE clause is expected");
4949 if (WsloopOp parentWsLoopOp = (*this)->getParentOfType<WsloopOp>()) {
4950 if (parentWsLoopOp.getReductionModAttr() &&
4951 parentWsLoopOp.getReductionModAttr().getValue() ==
4952 ReductionModifier::inscan)
4953 return success();
4954 }
4955 if (SimdOp parentSimdOp = (*this)->getParentOfType<SimdOp>()) {
4956 if (parentSimdOp.getReductionModAttr() &&
4957 parentSimdOp.getReductionModAttr().getValue() ==
4958 ReductionModifier::inscan)
4959 return success();
4960 }
4961 return emitError("SCAN directive needs to be enclosed within a parent "
4962 "worksharing loop construct or SIMD construct with INSCAN "
4963 "reduction modifier");
4964}
4965
4966/// Verifies align clause in allocate directive
4967LogicalResult verifyAlignment(Operation &op,
4968 std::optional<uint64_t> alignment) {
4969 if (alignment.has_value()) {
4970 if ((alignment.value() != 0) && !llvm::has_single_bit(alignment.value()))
4971 return op.emitError()
4972 << "ALIGN value : " << alignment.value() << " must be power of 2";
4973 }
4974 return success();
4975}
4976
4977LogicalResult AllocateDirOp::verify() {
4978 return verifyAlignment(*getOperation(), getAlign());
4979}
4980
4981//===----------------------------------------------------------------------===//
4982// AllocSharedMemOp
4983//===----------------------------------------------------------------------===//
4984
4985LogicalResult AllocSharedMemOp::verify() {
4986 return verifyAlignment(*getOperation(), getMemAlignment());
4987}
4988
4989//===----------------------------------------------------------------------===//
4990// FreeSharedMemOp
4991//===----------------------------------------------------------------------===//
4992
4993LogicalResult FreeSharedMemOp::verify() {
4994 return verifyAlignment(*getOperation(), getMemAlignment());
4995}
4996
4997//===----------------------------------------------------------------------===//
4998// WorkdistributeOp
4999//===----------------------------------------------------------------------===//
5000
5001LogicalResult WorkdistributeOp::verify() {
5002 // Check that region exists and is not empty
5003 Region &region = getRegion();
5004 if (region.empty())
5005 return emitOpError("region cannot be empty");
5006 // Verify single entry point.
5007 Block &entryBlock = region.front();
5008 if (entryBlock.empty())
5009 return emitOpError("region must contain a structured block");
5010 // Verify single exit point.
5011 bool hasTerminator = false;
5012 for (Block &block : region) {
5013 if (isa<TerminatorOp>(block.back())) {
5014 if (hasTerminator) {
5015 return emitOpError("region must have exactly one terminator");
5016 }
5017 hasTerminator = true;
5018 }
5019 }
5020 if (!hasTerminator) {
5021 return emitOpError("region must be terminated with omp.terminator");
5022 }
5023 auto walkResult = region.walk([&](Operation *op) -> WalkResult {
5024 // No implicit barrier at end
5025 if (isa<BarrierOp>(op)) {
5026 return emitOpError(
5027 "explicit barriers are not allowed in workdistribute region");
5028 }
5029 // Check for invalid nested constructs
5030 if (isa<ParallelOp>(op)) {
5031 return emitOpError(
5032 "nested parallel constructs not allowed in workdistribute");
5033 }
5034 if (isa<TeamsOp>(op)) {
5035 return emitOpError(
5036 "nested teams constructs not allowed in workdistribute");
5037 }
5038 return WalkResult::advance();
5039 });
5040 if (walkResult.wasInterrupted())
5041 return failure();
5042
5043 Operation *parentOp = (*this)->getParentOp();
5044 if (!llvm::dyn_cast<TeamsOp>(parentOp))
5045 return emitOpError("workdistribute must be nested under teams");
5046 return success();
5047}
5048
5049//===----------------------------------------------------------------------===//
5050// Declare simd [7.7]
5051//===----------------------------------------------------------------------===//
5052
5053LogicalResult DeclareSimdOp::verify() {
5054 // Must be nested inside a function-like op
5055 auto func =
5056 dyn_cast_if_present<mlir::FunctionOpInterface>((*this)->getParentOp());
5057 if (!func)
5058 return emitOpError() << "must be nested inside a function";
5059
5060 if (getInbranch() && getNotinbranch())
5061 return emitOpError("cannot have both 'inbranch' and 'notinbranch'");
5062
5063 if (failed(verifyLinearModifiers(*this, getLinearModifiers(), getLinearVars(),
5064 /*isDeclareSimd=*/true)))
5065 return failure();
5066
5067 return verifyAlignedClause(*this, getAlignments(), getAlignedVars());
5068}
5069
5070void DeclareSimdOp::build(OpBuilder &odsBuilder, OperationState &odsState,
5071 const DeclareSimdOperands &clauses) {
5072 MLIRContext *ctx = odsBuilder.getContext();
5073 DeclareSimdOp::build(odsBuilder, odsState, clauses.alignedVars,
5074 makeArrayAttr(ctx, clauses.alignments), clauses.inbranch,
5075 clauses.linearVars, clauses.linearStepVars,
5076 clauses.linearVarTypes, clauses.linearModifiers,
5077 clauses.notinbranch, clauses.simdlen,
5078 clauses.uniformVars);
5079}
5080
5081//===----------------------------------------------------------------------===//
5082// Parser and printer for Uniform Clause
5083//===----------------------------------------------------------------------===//
5084
5085/// uniform ::= `uniform` `(` uniform-list `)`
5086/// uniform-list := uniform-val (`,` uniform-val)*
5087/// uniform-val := ssa-id `:` type
5088static ParseResult
5091 SmallVectorImpl<Type> &uniformTypes) {
  // Parses a comma-separated list of `%operand : type` entries (no enclosing
  // delimiters). Each operand is appended to `uniformVars` and its type to
  // `uniformTypes`.
  // NOTE(review): the function name and leading parameters (original lines
  // 5089-5090) are missing from this extract. Presumably this is the
  // `uniform` clause parser described by the grammar comment above; confirm
  // against the real file.
5092 return parser.parseCommaSeparatedList([&]() -> mlir::ParseResult {
5093 if (parser.parseOperand(uniformVars.emplace_back()) ||
5094 parser.parseColonType(uniformTypes.emplace_back()))
5095 return mlir::failure();
5096 return mlir::success();
5097 });
5098}
5099
5100/// Print Uniform Clauses
5102 ValueRange uniformVars, TypeRange uniformTypes) {
  // Prints `%v0 : t0, %v1 : t1, ...`, pairing each value with the type at the
  // same index of `uniformTypes`.
  // NOTE(review): the function's first line (original line 5101, which
  // declares the printer `p`) is missing from this extract; confirm against
  // the real file.
5103 for (unsigned i = 0; i < uniformVars.size(); ++i) {
5104 if (i != 0)
5105 p << ", ";
5106 p << uniformVars[i] << " : " << uniformTypes[i];
5107 }
5108}
5109
5110//===----------------------------------------------------------------------===//
5111// Parser and printer for Affinity Clause
5112//===----------------------------------------------------------------------===//
5113
5114static ParseResult parseAffinityClause(
5115 OpAsmParser &parser,
5118 SmallVectorImpl<Type> &iteratedTypes,
5119 SmallVectorImpl<Type> &affinityVarTypes) {
  // Delegates to parseSplitIteratedList with the iterated and affinity
  // operand/type lists. The prefix callback parses nothing and always
  // succeeds, so no per-entry prefix is expected.
  // NOTE(review): original lines 5116-5117 are missing from this extract.
  // They presumably declare the `iterated` and `affinityVars` operand-list
  // parameters; confirm against the real file.
5120 if (failed(parseSplitIteratedList(
5121 parser, iterated, iteratedTypes, affinityVars, affinityVarTypes,
5122 /*parsePrefix=*/[&]() -> ParseResult { return success(); })))
5123 return failure();
5124 return success();
5125}
5126
5128 ValueRange iterated, ValueRange affinityVars,
5129 TypeRange iteratedTypes,
5130 TypeRange affinityVarTypes) {
  // Prints the iterated and plain affinity entries via printSplitIteratedList.
  // Both prefix callbacks are no-ops, so nothing is printed before an entry.
  // NOTE(review): the function's first line (original line 5127, which
  // declares the printer `p`) is missing from this extract; presumably this
  // is the printer counterpart of parseAffinityClause.
5131 auto nop = [&](Value, Type) {};
5132 printSplitIteratedList(p, iterated, iteratedTypes, affinityVars,
5133 affinityVarTypes,
5134 /*plain prefix*/ nop,
5135 /*iterated prefix*/ nop);
5136}
5137
5138//===----------------------------------------------------------------------===//
5139// Parser, printer, and verifier for Iterator modifier
5140//===----------------------------------------------------------------------===//
5141
5142static ParseResult
5147 SmallVectorImpl<Type> &lbTypes,
5148 SmallVectorImpl<Type> &ubTypes,
5149 SmallVectorImpl<Type> &stepTypes) {
  // Parses the iterator header in this order:
  //   1. the induction variables `%iv [: type], ...`;
  //   2. `) = (`;
  //   3. the ranges `%lb to %ub step %st, ...`;
  //   4. `)`;
  //   5. the body region, whose entry arguments are the induction variables.
  // The number of variables must equal the number of ranges. Each induction
  // variable's type (index when omitted) is also recorded as the type of its
  // lower bound, upper bound and step.
  // NOTE(review): original lines 5143-5146 (the function name and the
  // region/lbs/ubs/steps parameters) and 5152 (presumably the declaration of
  // `ivArgs`) are missing from this extract; confirm against the real file.
5150
5151 llvm::SMLoc ivLoc = parser.getCurrentLocation();
5153
5154 // Parse induction variables: %i : i32, %j : i32
5155 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5156 OpAsmParser::Argument &arg = ivArgs.emplace_back();
5157 if (parser.parseArgument(arg))
5158 return failure();
5159
5160 // Optional type, default to Index if not provided
5161 if (succeeded(parser.parseOptionalColon())) {
5162 if (parser.parseType(arg.type))
5163 return failure();
5164 } else {
5165 arg.type = parser.getBuilder().getIndexType();
5166 }
5167 return success();
5168 }))
5169 return failure();
5170
5171 // ) = (
5172 if (parser.parseRParen() || parser.parseEqual() || parser.parseLParen())
5173 return failure();
5174
5175 // Parse Ranges: (%lb to %ub step %st, ...)
5176 if (parser.parseCommaSeparatedList([&]() -> ParseResult {
5177 OpAsmParser::UnresolvedOperand lb, ub, st;
5178 if (parser.parseOperand(lb) || parser.parseKeyword("to") ||
5179 parser.parseOperand(ub) || parser.parseKeyword("step") ||
5180 parser.parseOperand(st))
5181 return failure();
5182
5183 lbs.push_back(lb);
5184 ubs.push_back(ub);
5185 steps.push_back(st);
5186 return success();
5187 }))
5188 return failure();
5189
5190 if (parser.parseRParen())
5191 return failure();
5192
  // Each induction variable needs exactly one range.
5193 if (ivArgs.size() != lbs.size())
5194 return parser.emitError(ivLoc)
5195 << "mismatch: " << ivArgs.size() << " variables but " << lbs.size()
5196 << " ranges";
5197
  // Bounds and step share the type of their induction variable.
5198 for (auto &arg : ivArgs) {
5199 lbTypes.push_back(arg.type);
5200 ubTypes.push_back(arg.type);
5201 stepTypes.push_back(arg.type);
5202 }
5203
5204 return parser.parseRegion(region, ivArgs);
5205}
5206
5208 ValueRange lbs, ValueRange ubs,
5210 TypeRange) {
  // Prints the iterator header in this order:
  //   1. each entry-block argument (via printRegionArgument);
  //   2. `) = (`;
  //   3. one `lb to ub step st` triple per range;
  //   4. `) ` followed by the region, without its entry-block arguments.
  // The trailing unnamed TypeRange parameter is not used.
  // NOTE(review): original lines 5207 and 5209 (the function name, the
  // printer `p`, the `region` and the `steps` parameters) are missing from
  // this extract; confirm against the real file.
5211 Block &entry = region.front();
5212
5213 for (unsigned i = 0, e = entry.getNumArguments(); i < e; ++i) {
5214 if (i != 0)
5215 p << ", ";
5216 p.printRegionArgument(entry.getArgument(i));
5217 }
5218 p << ") = (";
5219
5220 // (%lb0 to %ub0 step %step0, %lb1 to %ub1 step %step1, ...)
5221 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
5222 if (i)
5223 p << ", ";
5224 p << lbs[i] << " to " << ubs[i] << " step " << steps[i];
5225 }
5226 p << ") ";
5227
5228 p.printRegion(region, /*printEntryBlockArgs=*/false,
5229 /*printBlockTerminators=*/true);
5230}
5231
5232LogicalResult IteratorOp::verify() {
5233 auto iteratedTy = llvm::dyn_cast<omp::IteratedType>(getIterated().getType());
5234 if (!iteratedTy)
5235 return emitOpError() << "result must be omp.iterated<entry_ty>";
5236
5237 for (auto [lb, ub, step] : llvm::zip_equal(
5238 getLoopLowerBounds(), getLoopUpperBounds(), getLoopSteps())) {
5239 if (matchPattern(step, m_Zero()))
5240 return emitOpError() << "loop step must not be zero";
5241
5242 IntegerAttr lbAttr;
5243 IntegerAttr ubAttr;
5244 IntegerAttr stepAttr;
5245 if (!matchPattern(lb, m_Constant(&lbAttr)) ||
5246 !matchPattern(ub, m_Constant(&ubAttr)) ||
5247 !matchPattern(step, m_Constant(&stepAttr)))
5248 continue;
5249
5250 const APInt &lbVal = lbAttr.getValue();
5251 const APInt &ubVal = ubAttr.getValue();
5252 const APInt &stepVal = stepAttr.getValue();
5253 if (stepVal.isStrictlyPositive() && lbVal.sgt(ubVal))
5254 return emitOpError() << "positive loop step requires lower bound to be "
5255 "less than or equal to upper bound";
5256 if (stepVal.isNegative() && lbVal.slt(ubVal))
5257 return emitOpError() << "negative loop step requires lower bound to be "
5258 "greater than or equal to upper bound";
5259 }
5260
5261 Block &b = getRegion().front();
5262 auto yield = llvm::dyn_cast<omp::YieldOp>(b.getTerminator());
5263
5264 if (!yield)
5265 return emitOpError() << "region must be terminated by omp.yield";
5266
5267 if (yield.getNumOperands() != 1)
5268 return emitOpError()
5269 << "omp.yield in omp.iterator region must yield exactly one value";
5270
5271 mlir::Type yieldedTy = yield.getOperand(0).getType();
5272 mlir::Type elemTy = iteratedTy.getElementType();
5273
5274 if (yieldedTy != elemTy)
5275 return emitOpError() << "omp.iterated element type (" << elemTy
5276 << ") does not match omp.yield operand type ("
5277 << yieldedTy << ")";
5278
5279 return success();
5280}
5281
5282//===----------------------------------------------------------------------===//
5283// GroupprivateOp
5284//===----------------------------------------------------------------------===//
5285
5286LogicalResult
5287GroupprivateOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
5288 auto *symbol = symbolTable.lookupNearestSymbolFrom(*this, getSymNameAttr());
5289 if (!symbol)
5290 return emitOpError() << "expected symbol reference '" << getSymName()
5291 << "' to point to a global variable";
5292
5293 if (isa<FunctionOpInterface>(symbol))
5294 return emitOpError() << "expected symbol reference '" << getSymName()
5295 << "' to point to a global variable, not a function";
5296
5297 return success();
5298}
5299
5300#define GET_ATTRDEF_CLASSES
5301#include "mlir/Dialect/OpenMP/OpenMPOpsAttributes.cpp.inc"
5302
5303#define GET_OP_CLASSES
5304#include "mlir/Dialect/OpenMP/OpenMPOps.cpp.inc"
5305
5306#define GET_TYPEDEF_CLASSES
5307#include "mlir/Dialect/OpenMP/OpenMPOpsTypes.cpp.inc"
return success()
p<< " : "<< getMemRefType()<< ", "<< getType();}static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType, VectorType vectorType) { if(memrefType.getElementType() !=vectorType.getElementType()) return op-> emitOpError("requires memref and vector types of the same elemental type")
Given a list of lists of parsed operands, populates uniqueOperands with unique operands.
static std::optional< int64_t > getUpperBound(Value iv)
Gets the constant upper bound on an affine.for iv.
static LogicalResult verifyRegion(emitc::SwitchOp op, Region &region, const Twine &name)
Definition EmitC.cpp:1533
static Type getElementType(Type type)
Determine the element type of type.
b
Return true if permutation is a valid permutation of the outer_dims_perm (case OuterOrInnerPerm::Oute...
ArrayAttr()
b getContext())
static const mlir::GenInfo * generator
static LogicalResult verifyNontemporalClause(Operation *op, OperandRange nontemporalVars)
static DenseI64ArrayAttr makeDenseI64ArrayAttr(MLIRContext *ctx, const ArrayRef< int64_t > intArray)
static void printDependVarList(OpAsmPrinter &p, Operation *op, OperandRange dependVars, TypeRange dependTypes, std::optional< ArrayAttr > dependKinds, OperandRange iteratedVars, TypeRange iteratedTypes, std::optional< ArrayAttr > iteratedKinds)
Print Depend clause.
static constexpr StringRef getPrivateNeedsBarrierSpelling()
static void printHeapAllocClause(OpAsmPrinter &p, Operation *op, TypeAttr inType, ValueRange typeparams, TypeRange typeparamsTypes, ValueRange shape, TypeRange shapeTypes)
static LogicalResult verifyReductionVarList(Operation *op, std::optional< ArrayAttr > reductionSyms, OperandRange reductionVars, std::optional< ArrayRef< bool > > reductionByref)
Verifies Reduction Clause.
static ParseResult parseLinearClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearVars, SmallVectorImpl< Type > &linearTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &linearStepVars, SmallVectorImpl< Type > &linearStepTypes, ArrayAttr &linearModifiers)
linear ::= linear ( linear-list ) linear-list := linear-val | linear-val linear-list linear-val := ss...
static ParseResult parseInReductionPrivateRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static ArrayAttr makeArrayAttr(MLIRContext *context, llvm::ArrayRef< Attribute > attrs)
static ParseResult parseClauseAttr(AsmParser &parser, ClauseAttr &attr)
static void printDynGroupprivateClause(OpAsmPrinter &printer, Operation *op, AccessGroupModifierAttr modifierFirst, FallbackModifierAttr modifierSecond, Value dynGroupprivateSize, Type sizeType)
static void printAllocateAndAllocator(OpAsmPrinter &p, Operation *op, OperandRange allocateVars, TypeRange allocateTypes, OperandRange allocatorVars, TypeRange allocatorTypes)
Print allocate clause.
static DenseBoolArrayAttr makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef< bool > boolArray)
static std::string generateLoopNestingName(StringRef prefix, CanonicalLoopOp op)
Generate a name of a canonical loop nest of the format <prefix>(_r<idx>_s<idx>)*.
static ParseResult parseAffinityClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iterated, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &affinityVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< Type > &affinityVarTypes)
static void printClauseWithRegionArgs(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, ValueRange operands, TypeRange types, ArrayAttr symbols=nullptr, DenseI64ArrayAttr mapIndices=nullptr, DenseBoolArrayAttr byref=nullptr, ReductionModifierAttr modifier=nullptr, UnitAttr needsBarrier=nullptr)
static void printSplitIteratedList(OpAsmPrinter &p, ValueRange iteratedVars, TypeRange iteratedTypes, ValueRange plainVars, TypeRange plainTypes, PrintPrefixFn &&printPrefixForPlain, PrintPrefixFn &&printPrefixForIterated)
static LogicalResult verifyDependVarList(Operation *op, std::optional< ArrayAttr > dependKinds, OperandRange dependVars, std::optional< ArrayAttr > iteratedKinds, OperandRange iteratedVars)
Verifies Depend clause.
static void printBlockArgClause(OpAsmPrinter &p, MLIRContext *ctx, StringRef clauseName, ValueRange argsSubrange, std::optional< MapPrintArgs > mapArgs)
static void printAffinityClause(OpAsmPrinter &p, Operation *op, ValueRange iterated, ValueRange affinityVars, TypeRange iteratedTypes, TypeRange affinityVarTypes)
static void printBlockArgRegion(OpAsmPrinter &p, Operation *op, Region &region, const AllRegionPrintArgs &args)
static ParseResult parseGranularityClause(OpAsmParser &parser, ClauseTypeAttr &prescriptiveness, std::optional< OpAsmParser::UnresolvedOperand > &operand, Type &operandType, std::optional< ClauseType >(*symbolizeClause)(StringRef), StringRef clauseName)
static void printIteratorHeader(OpAsmPrinter &p, Operation *op, Region &region, ValueRange lbs, ValueRange ubs, ValueRange steps, TypeRange, TypeRange, TypeRange)
static ParseResult parseHeapAllocClause(OpAsmParser &parser, TypeAttr &inTypeAttr, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &typeparams, SmallVectorImpl< Type > &typeparamsTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &shape, SmallVectorImpl< Type > &shapeTypes)
operation ::= $in_type ( ( $typeparams ) )? ( , $shape )?
static ParseResult parseIteratorHeader(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &lbs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &ubs, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &steps, SmallVectorImpl< Type > &lbTypes, SmallVectorImpl< Type > &ubTypes, SmallVectorImpl< Type > &stepTypes)
static ParseResult parseBlockArgRegion(OpAsmParser &parser, Region &region, AllRegionParseArgs args)
static ParseResult parseLoopTransformClis(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &generateesOperands, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &applyeesOperands)
static ParseResult parseSynchronizationHint(OpAsmParser &parser, IntegerAttr &hintAttr)
Parses a Synchronization Hint clause.
static void printScheduleClause(OpAsmPrinter &p, Operation *op, ClauseScheduleKindAttr scheduleKind, ScheduleModifierAttr scheduleMod, UnitAttr scheduleSimd, Value scheduleChunk, Type scheduleChunkType)
Print schedule clause.
static void printCopyprivate(OpAsmPrinter &p, Operation *op, OperandRange copyprivateVars, TypeRange copyprivateTypes, std::optional< ArrayAttr > copyprivateSyms)
Print Copyprivate clause.
static ParseResult parseOrderClause(OpAsmParser &parser, ClauseOrderKindAttr &order, OrderModifierAttr &orderMod)
static bool mapTypeToBool(ClauseMapFlags value, ClauseMapFlags flag)
static void printAlignedClause(OpAsmPrinter &p, Operation *op, ValueRange alignedVars, TypeRange alignedTypes, std::optional< ArrayAttr > alignments)
Print Aligned Clause.
static LogicalResult verifySynchronizationHint(Operation *op, uint64_t hint)
Verifies a synchronization hint clause.
static ParseResult parseUseDeviceAddrUseDevicePtrRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDeviceAddrVars, SmallVectorImpl< Type > &useDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &useDevicePtrVars, SmallVectorImpl< Type > &useDevicePtrTypes)
static Operation * findCapturedOmpOp(Operation *rootOp, bool checkSingleMandatoryExec, llvm::function_ref< bool(Operation *)> siblingAllowedFn)
static ParseResult parseUniformClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &uniformVars, SmallVectorImpl< Type > &uniformTypes)
uniform ::= uniform ( uniform-list ) uniform-list := uniform-val (, uniform-val)* uniform-val := ssa-...
static void printInReductionPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static void printInReductionPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printSynchronizationHint(OpAsmPrinter &p, Operation *op, IntegerAttr hintAttr)
Prints a Synchronization Hint clause.
static void printGranularityClause(OpAsmPrinter &p, Operation *op, ClauseTypeAttr prescriptiveness, Value operand, mlir::Type operandType, StringRef(*stringifyClauseType)(ClauseType))
static ParseResult parseDependVarList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &dependVars, SmallVectorImpl< Type > &dependTypes, ArrayAttr &dependKinds, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, ArrayAttr &iteratedKinds)
depend-entry-list ::= depend-entry | depend-entry-list , depend-entry depend-entry ::= depend-kind ->...
static void printTargetOpRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange hasDeviceAddrVars, TypeRange hasDeviceAddrTypes, ValueRange hostEvalVars, TypeRange hostEvalTypes, ValueRange inReductionVars, TypeRange inReductionTypes, DenseBoolArrayAttr inReductionByref, ArrayAttr inReductionSyms, ValueRange mapVars, TypeRange mapTypes, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, DenseI64ArrayAttr privateMaps)
static Operation * getParentInSameDialect(Operation *thisOp)
static void printUniformClause(OpAsmPrinter &p, Operation *op, ValueRange uniformVars, TypeRange uniformTypes)
Print Uniform Clauses.
static LogicalResult verifyCopyprivateVarList(Operation *op, OperandRange copyprivateVars, std::optional< ArrayAttr > copyprivateSyms)
Verifies CopyPrivate Clause.
static LogicalResult verifyAlignedClause(Operation *op, std::optional< ArrayAttr > alignments, OperandRange alignedVars)
static ParseResult parseTargetOpRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hasDeviceAddrVars, SmallVectorImpl< Type > &hasDeviceAddrTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &hostEvalVars, SmallVectorImpl< Type > &hostEvalTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &mapVars, SmallVectorImpl< Type > &mapTypes, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, DenseI64ArrayAttr &privateMaps)
static ParseResult parsePrivateRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier)
static void printNumTasksClause(OpAsmPrinter &p, Operation *op, ClauseNumTasksTypeAttr numTasksMod, Value numTasks, mlir::Type numTasksType)
static void printLoopTransformClis(OpAsmPrinter &p, TileOp op, OperandRange generatees, OperandRange applyees)
static ParseResult parseDynGroupprivateClause(OpAsmParser &parser, AccessGroupModifierAttr &accessGroupAttr, FallbackModifierAttr &fallbackAttr, std::optional< OpAsmParser::UnresolvedOperand > &dynGroupprivateSize, Type &sizeType)
static void printPrivateRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier)
static void printPrivateReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange privateVars, TypeRange privateTypes, ArrayAttr privateSyms, UnitAttr privateNeedsBarrier, ReductionModifierAttr reductionMod, ValueRange reductionVars, TypeRange reductionTypes, DenseBoolArrayAttr reductionByref, ArrayAttr reductionSyms)
static ParseResult parseSplitIteratedList(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &iteratedVars, SmallVectorImpl< Type > &iteratedTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &plainVars, SmallVectorImpl< Type > &plainTypes, ParsePrefixFn &&parsePrefix)
static void printTaskReductionRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange taskReductionVars, TypeRange taskReductionTypes, DenseBoolArrayAttr taskReductionByref, ArrayAttr taskReductionSyms)
static LogicalResult verifyMapInfoForMapClause(Operation *op, mlir::omp::MapInfoOp mapInfoOp, llvm::DenseSet< mlir::TypedValue< mlir::omp::PointerLikeType > > &updateToVars, llvm::DenseSet< mlir::TypedValue< mlir::omp::PointerLikeType > > &updateFromVars)
return success()
static LogicalResult verifyOrderedParent(Operation &op)
static void printOrderClause(OpAsmPrinter &p, Operation *op, ClauseOrderKindAttr order, OrderModifierAttr orderMod)
static ParseResult parseBlockArgClause(OpAsmParser &parser, llvm::SmallVectorImpl< OpAsmParser::Argument > &entryBlockArgs, StringRef keyword, std::optional< MapParseArgs > mapArgs)
static ParseResult parseClauseWithRegionArgs(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &operands, SmallVectorImpl< Type > &types, SmallVectorImpl< OpAsmParser::Argument > &regionPrivateArgs, ArrayAttr *symbols=nullptr, DenseI64ArrayAttr *mapIndices=nullptr, DenseBoolArrayAttr *byref=nullptr, ReductionModifierAttr *modifier=nullptr, UnitAttr *needsBarrier=nullptr)
static LogicalResult verifyPrivateVarsMapping(TargetOp targetOp)
static ParseResult parseScheduleClause(OpAsmParser &parser, ClauseScheduleKindAttr &scheduleAttr, ScheduleModifierAttr &scheduleMod, UnitAttr &scheduleSimd, std::optional< OpAsmParser::UnresolvedOperand > &chunkSize, Type &chunkType)
schedule ::= schedule ( sched-list ) sched-list ::= sched-val | sched-val sched-list | sched-val ,...
static LogicalResult verifyDynGroupprivateClause(Operation *op, AccessGroupModifierAttr accessGroup, FallbackModifierAttr fallback, Value dynGroupprivateSize)
static LogicalResult verifyLinearModifiers(Operation *op, std::optional< ArrayAttr > linearModifiers, OperandRange linearVars, bool isDeclareSimd=false)
OpenMP 5.2, Section 5.4.6: "A linear-modifier may be specified as ref or uval only on a declare simd ...
static void printClauseAttr(OpAsmPrinter &p, Operation *op, ClauseAttr attr)
static ParseResult parseAllocateAndAllocator(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocateVars, SmallVectorImpl< Type > &allocateTypes, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &allocatorVars, SmallVectorImpl< Type > &allocatorTypes)
Parse an allocate clause with allocators and a list of operands with types.
static void printMembersIndex(OpAsmPrinter &p, MapInfoOp op, ArrayAttr membersIdx)
static bool canPromoteToNoLoop(Operation *capturedOp, TeamsOp teamsOp, WsloopOp *wsLoopOp)
Check if we can promote SPMD kernel to No-Loop kernel.
static void printCaptureType(OpAsmPrinter &p, Operation *op, VariableCaptureKindAttr mapCaptureType)
static LogicalResult verifyNumTeamsClause(Operation *op, Value numTeamsLower, OperandRange numTeamsUpperVars)
static bool opInGlobalImplicitParallelRegion(Operation *op)
static void printUseDeviceAddrUseDevicePtrRegion(OpAsmPrinter &p, Operation *op, Region &region, ValueRange useDeviceAddrVars, TypeRange useDeviceAddrTypes, ValueRange useDevicePtrVars, TypeRange useDevicePtrTypes)
static LogicalResult verifyMapClause(Operation *op, OperandRange mapVars, OperandRange mapIterated)
static LogicalResult verifyPrivateVarList(OpType &op)
static ParseResult parseNumTasksClause(OpAsmParser &parser, ClauseNumTasksTypeAttr &numTasksMod, std::optional< OpAsmParser::UnresolvedOperand > &numTasks, Type &numTasksType)
LogicalResult verifyAlignment(Operation &op, std::optional< uint64_t > alignment)
Verifies align clause in allocate directive.
static ParseResult parseAlignedClause(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &alignedVars, SmallVectorImpl< Type > &alignedTypes, ArrayAttr &alignmentsAttr)
aligned ::= aligned ( aligned-list ) aligned-list := aligned-val | aligned-val aligned-list aligned-v...
static ParseResult parsePrivateReductionRegion(OpAsmParser &parser, Region &region, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static void printLinearClause(OpAsmPrinter &p, Operation *op, ValueRange linearVars, TypeRange linearTypes, ValueRange linearStepVars, TypeRange stepVarTypes, ArrayAttr linearModifiers)
Print Linear Clause.
static ParseResult parseInReductionPrivateReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &inReductionVars, SmallVectorImpl< Type > &inReductionTypes, DenseBoolArrayAttr &inReductionByref, ArrayAttr &inReductionSyms, llvm::SmallVectorImpl< OpAsmParser::UnresolvedOperand > &privateVars, llvm::SmallVectorImpl< Type > &privateTypes, ArrayAttr &privateSyms, UnitAttr &privateNeedsBarrier, ReductionModifierAttr &reductionMod, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &reductionVars, SmallVectorImpl< Type > &reductionTypes, DenseBoolArrayAttr &reductionByref, ArrayAttr &reductionSyms)
static LogicalResult checkApplyeesNesting(TileOp op)
Check properties of the loop nest consisting of the transformation's applyees:
static ParseResult parseCaptureType(OpAsmParser &parser, VariableCaptureKindAttr &mapCaptureType)
static ParseResult parseTaskReductionRegion(OpAsmParser &parser, Region &region, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &taskReductionVars, SmallVectorImpl< Type > &taskReductionTypes, DenseBoolArrayAttr &taskReductionByref, ArrayAttr &taskReductionSyms)
static ParseResult parseGrainsizeClause(OpAsmParser &parser, ClauseGrainsizeTypeAttr &grainsizeMod, std::optional< OpAsmParser::UnresolvedOperand > &grainsize, Type &grainsizeType)
static ParseResult parseCopyprivate(OpAsmParser &parser, SmallVectorImpl< OpAsmParser::UnresolvedOperand > &copyprivateVars, SmallVectorImpl< Type > &copyprivateTypes, ArrayAttr &copyprivateSyms)
copyprivate-entry-list ::= copyprivate-entry | copyprivate-entry-list , copyprivate-entry copyprivate...
static LogicalResult verifyMapInfoDefinedArgs(Operation *op, StringRef clauseName, OperandRange vars)
static void printGrainsizeClause(OpAsmPrinter &p, Operation *op, ClauseGrainsizeTypeAttr grainsizeMod, Value grainsize, mlir::Type grainsizeType)
static ParseResult verifyScheduleModifiers(OpAsmParser &parser, SmallVectorImpl< SmallString< 12 > > &modifiers)
static bool isUnique(It begin, It end)
Definition ShardOps.cpp:161
static LogicalResult emit(SolverOp solver, const SMTEmissionOptions &options, mlir::raw_indented_ostream &stream)
Emit the SMT operations in the given 'solver' to the 'stream'.
static SmallVector< Value > getTileSizes(Location loc, x86::amx::TileType tType, RewriterBase &rewriter)
Maps the 2-dim vector shape to the two 16-bit tile sizes.
This base class exposes generic asm parser hooks, usable across the various derived parsers.
virtual ParseResult parseMinus()=0
Parse a '-' token.
@ Paren
Parens surrounding zero or more operands.
@ None
Zero or more operands with no delimiters.
virtual ParseResult parseColonTypeList(SmallVectorImpl< Type > &result)=0
Parse a colon followed by a type list, which must have at least one type.
virtual Builder & getBuilder() const =0
Return a builder which provides useful access to MLIRContext, global objects like types and attribute...
virtual ParseResult parseCommaSeparatedList(Delimiter delimiter, function_ref< ParseResult()> parseElementFn, StringRef contextMessage=StringRef())=0
Parse a list of comma-separated items with an optional delimiter.
virtual ParseResult parseOptionalAttrDict(NamedAttrList &result)=0
Parse a named dictionary into 'result' if it is present.
virtual ParseResult parseOptionalEqual()=0
Parse a = token if present.
virtual ParseResult parseOptionalKeyword(StringRef keyword)=0
Parse the given keyword if present.
MLIRContext * getContext() const
virtual ParseResult parseRParen()=0
Parse a ) token.
virtual InFlightDiagnostic emitError(SMLoc loc, const Twine &message={})=0
Emit a diagnostic at the specified location and return failure.
virtual ParseResult parseOptionalColon()=0
Parse a : token if present.
virtual ParseResult parseLSquare()=0
Parse a [ token.
virtual ParseResult parseRSquare()=0
Parse a ] token.
ParseResult parseInteger(IntT &result)
Parse an integer value from the stream.
virtual ParseResult parseOptionalArrow()=0
Parse a '->' token if present.
virtual ParseResult parseLess()=0
Parse a '<' token.
virtual ParseResult parseEqual()=0
Parse a = token.
virtual ParseResult parseColonType(Type &result)=0
Parse a colon followed by a type.
virtual SMLoc getCurrentLocation()=0
Get the location of the next token and store it into the argument.
virtual ParseResult parseOptionalComma()=0
Parse a , token if present.
virtual ParseResult parseColon()=0
Parse a : token.
virtual SMLoc getNameLoc() const =0
Return the location of the original name token.
virtual ParseResult parseOptionalLess()=0
Parse a '<' token if present.
virtual ParseResult parseArrow()=0
Parse a '->' token.
virtual ParseResult parseLParen()=0
Parse a ( token.
virtual ParseResult parseType(Type &result)=0
Parse a type.
virtual ParseResult parseComma()=0
Parse a , token.
virtual ParseResult parseOptionalLParen()=0
Parse a ( token if present.
ParseResult parseKeyword(StringRef keyword)
Parse a given keyword.
Attributes are known-constant values of operations.
Definition Attributes.h:25
Block represents an ordered list of Operations.
Definition Block.h:33
ValueTypeRange< BlockArgListType > getArgumentTypes()
Return a range containing the types of the arguments for this block.
Definition Block.cpp:154
bool empty()
Definition Block.h:158
BlockArgument getArgument(unsigned i)
Definition Block.h:139
unsigned getNumArguments()
Definition Block.h:138
Operation & front()
Definition Block.h:163
SuccessorRange getSuccessors()
Definition Block.h:280
Operation & back()
Definition Block.h:162
Operation * getTerminator()
Get the terminator operation of this block.
Definition Block.cpp:249
bool mightHaveTerminator()
Return "true" if this block might have a terminator.
Definition Block.cpp:255
BlockArgListType getArguments()
Definition Block.h:97
IntegerType getI64Type()
Definition Builders.cpp:69
IntegerType getIntegerType(unsigned width)
Definition Builders.cpp:71
MLIRContext * getContext() const
Definition Builders.h:56
Attr getAttr(Args &&...args)
Get or construct an instance of the attribute Attr with provided arguments.
Definition Builders.h:100
Diagnostic & append(Arg1 &&arg1, Arg2 &&arg2, Args &&...args)
Append arguments to the diagnostic.
Diagnostic & appendOp(Operation &op, const OpPrintingFlags &flags)
Append an operation with the given printing flags.
Dialects are groups of MLIR operations, types and attributes, as well as behavior associated with the...
Definition Dialect.h:38
A class for computing basic dominance information.
Definition Dominance.h:143
bool dominates(Operation *a, Operation *b) const
Return true if operation A dominates operation B, i.e.
Definition Dominance.h:161
This class represents a diagnostic that is inflight and set to be reported.
Diagnostic & attachNote(std::optional< Location > noteLoc=std::nullopt)
Attaches a note to this diagnostic.
MLIRContext is the top-level object for a collection of MLIR operations.
Definition MLIRContext.h:63
The OpAsmParser has methods for interacting with the asm parser: parsing things from it,...
virtual ParseResult parseRegion(Region &region, ArrayRef< Argument > arguments={}, bool enableNameShadowing=false)=0
Parses a region.
virtual ParseResult parseArgument(Argument &result, bool allowType=false, bool allowAttrs=false)=0
Parse a single argument with the following syntax:
virtual ParseResult parseArgumentList(SmallVectorImpl< Argument > &result, Delimiter delimiter=Delimiter::None, bool allowType=false, bool allowAttrs=false)=0
Parse zero or more arguments with a specified surrounding delimiter.
virtual ParseResult resolveOperand(const UnresolvedOperand &operand, Type type, SmallVectorImpl< Value > &result)=0
Resolve an operand to an SSA value, emitting an error on failure.
ParseResult resolveOperands(Operands &&operands, Type type, SmallVectorImpl< Value > &result)
Resolve a list of operands to SSA values, emitting an error on failure, or appending the results to t...
virtual ParseResult parseOperand(UnresolvedOperand &result, bool allowResultNumber=true)=0
Parse a single SSA value operand name along with a result number if allowResultNumber is true.
virtual ParseResult parseOperandList(SmallVectorImpl< UnresolvedOperand > &result, Delimiter delimiter=Delimiter::None, bool allowResultNumber=true, int requiredOperandCount=-1)=0
Parse zero or more SSA comma-separated operand references with a specified surrounding delimiter,...
This is a pure-virtual base class that exposes the asmprinter hooks necessary to implement a custom p...
virtual void printOptionalAttrDict(ArrayRef< NamedAttribute > attrs, ArrayRef< StringRef > elidedAttrs={})=0
If the specified operation has attributes, print out an attribute dictionary with their values.
virtual void printRegion(Region &blocks, bool printEntryBlockArgs=true, bool printBlockTerminators=true, bool printEmptyBlock=false)=0
Prints a region.
virtual void printRegionArgument(BlockArgument arg, ArrayRef< NamedAttribute > argAttrs={}, bool omitType=false)=0
Print a block argument in the usual format of: ssaName : type {attr1=42} loc("here") where location p...
virtual void printOperand(Value value)=0
Print implementations for various things an operation contains.
This class helps build Operations.
Definition Builders.h:209
This class represents an operand of an operation.
Definition Value.h:254
Set of flags used to control the behavior of the various IR print methods (e.g.
This class provides the API for ops that are known to be isolated from above.
This class provides the API for ops that are known to be terminators.
This class indicates that the regions associated with this op don't have terminators.
This class implements the operand iterators for the Operation class.
Definition ValueRange.h:44
type_range getType() const
Operation is the basic unit of execution within MLIR.
Definition Operation.h:87
Dialect * getDialect()
Return the dialect this operation is associated with, or nullptr if the associated dialect is not loa...
Definition Operation.h:237
Region & getRegion(unsigned index)
Returns the region held by this operation at position 'index'.
Definition Operation.h:711
bool hasTrait()
Returns true if the operation was registered with a particular trait, e.g.
Definition Operation.h:774
Block * getBlock()
Returns the operation block that contains this operation.
Definition Operation.h:230
unsigned getNumRegions()
Returns the number of regions held by this operation.
Definition Operation.h:699
Location getLoc()
The source location the operation was defined or derived from.
Definition Operation.h:240
Operation * getParentOp()
Returns the closest surrounding operation that contains this operation or nullptr if this is a top-le...
Definition Operation.h:251
InFlightDiagnostic emitError(const Twine &message={})
Emit an error about fatal conditions with this operation, reporting up to any diagnostic handlers tha...
OpTy getParentOfType()
Return the closest surrounding parent operation that is of type 'OpTy'.
Definition Operation.h:255
MutableArrayRef< Region > getRegions()
Returns the regions held by this operation.
Definition Operation.h:702
operand_range getOperands()
Returns an iterator on the underlying Value's.
Definition Operation.h:403
std::enable_if_t< llvm::function_traits< std::decay_t< FnT > >::num_args==1, RetT > walk(FnT &&callback)
Walk the operation by calling the callback for each nested operation (including this one),...
Definition Operation.h:822
user_range getUsers()
Returns a range of all users.
Definition Operation.h:898
Region * getParentRegion()
Returns the region to which the instruction belongs.
Definition Operation.h:247
MLIRContext * getContext()
Return the context this operation is associated with.
Definition Operation.h:233
InFlightDiagnostic emitOpError(const Twine &message={})
Emit an error with the op name prefixed, like "'dim' op " which is convenient for verifiers.
A special type of RewriterBase that coordinates the application of a rewrite pattern on the current I...
This class contains a list of basic blocks and a link to the parent operation it is attached to.
Definition Region.h:26
Block & front()
Definition Region.h:65
BlockArgListType getArguments()
Definition Region.h:81
OpIterator op_begin()
Return iterators that walk the operations nested directly within this region.
Definition Region.h:170
bool isAncestor(Region *other)
Return true if this region is ancestor of the other region.
Definition Region.h:233
iterator_range< OpIterator > getOps()
Definition Region.h:172
bool empty()
Definition Region.h:60
unsigned getNumArguments()
Definition Region.h:123
Location getLoc()
Return a location for this region.
Definition Region.cpp:31
BlockArgument getArgument(unsigned i)
Definition Region.h:124
Operation * getParentOp()
Return the parent operation this region is attached to.
Definition Region.h:200
BlockListType & getBlocks()
Definition Region.h:45
virtual void eraseOp(Operation *op)
This method erases an operation that is known to have no uses.
OpTy replaceOpWithNewOp(Operation *op, Args &&...args)
Replace the results of the given (original) op with a new op that is created without verification (re...
Resource * getResource() const
Return the resource that the effect applies to.
EffectT * getEffect() const
Return the effect being applied.
This class represents a collection of SymbolTables.
virtual Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
static Operation * lookupNearestSymbolFrom(Operation *from, StringAttr symbol)
Returns the operation registered with the given symbol name within the closest parent operation of,...
This class provides an abstraction over the various different ranges of value types.
Definition TypeRange.h:40
Instances of the Type class are uniqued, have an immutable identifier and an optional mutable compone...
Definition Types.h:74
This class provides an abstraction over the different types of ranges over Values.
Definition ValueRange.h:389
type_range getType() const
This class represents an instance of an SSA value in the MLIR system, representing a computable value...
Definition Value.h:96
MLIRContext * getContext() const
Utility to get the associated MLIRContext that this value is defined in.
Definition Value.h:108
Type getType() const
Return the type of this value.
Definition Value.h:105
use_range getUses() const
Returns a range of all uses, which is useful for iterating over all uses.
Definition Value.h:188
Operation * getDefiningOp() const
If this value is the result of an operation, return the operation that defines it.
Definition Value.cpp:18
A utility result that is used to signal how to proceed with an ongoing walk:
Definition WalkResult.h:29
static WalkResult skip()
Definition WalkResult.h:48
static WalkResult advance()
Definition WalkResult.h:47
static WalkResult interrupt()
Definition WalkResult.h:46
static DenseArrayAttrImpl get(MLIRContext *context, ArrayRef< bool > content)
bool isReachableFromEntry(Block *a) const
Return true if the specified block is reachable from the entry block of its region.
Operation * getOwner() const
Return the owner of this operand.
Definition UseDefLists.h:38
SideEffects::EffectInstance< Effect > EffectInstance
TargetEnterDataOperands TargetEnterExitUpdateDataOperands
omp.target_enter_data, omp.target_exit_data and omp.target_update take the same clauses,...
std::tuple< NewCliOp, OpOperand *, OpOperand * > decodeCli(mlir::Value cli)
Find the omp.new_cli, generator, and consumer of a canonical loop info.
detail::InFlightRemark failed(Location loc, RemarkOpts opts)
Report an optimization remark that failed.
Definition Remarks.h:717
Include the generated interface declarations.
bool matchPattern(Value value, const Pattern &pattern)
Entry point for matching a pattern over a Value.
Definition Matchers.h:490
detail::DenseArrayAttrImpl< int64_t > DenseI64ArrayAttr
function_ref< void(Value, StringRef)> OpAsmSetValueNameFn
A functor used to set the name of the start of a result group of an operation.
Type getType(OpFoldResult ofr)
Returns the int type of the integer in ofr.
Definition Utils.cpp:307
llvm::DenseSet< ValueT, ValueInfoT > DenseSet
Definition LLVM.h:122
InFlightDiagnostic emitError(Location loc)
Utility method to emit an error message using this location.
bool isPure(Operation *op)
Returns true if the given operation is pure, i.e., is speculatable that does not touch memory.
detail::constant_int_predicate_matcher m_Zero()
Matches a constant scalar / vector splat / tensor splat integer zero.
Definition Matchers.h:442
std::conditional_t< std::is_same_v< Ty, mlir::Type >, mlir::Value, detail::TypedValue< Ty > > TypedValue
If Ty is mlir::Type this will select Value instead of having a wrapper around it.
Definition Value.h:494
llvm::TypeSwitch< T, ResultT > TypeSwitch
Definition LLVM.h:139
SmallVector< Loops, 8 > tile(ArrayRef< scf::ForOp > forOps, ArrayRef< Value > sizes, ArrayRef< scf::ForOp > targets)
Performs tiling fo imperfectly nested loops (with interchange) by strip-mining the forOps by sizes an...
Definition Utils.cpp:1330
detail::DenseArrayAttrImpl< bool > DenseBoolArrayAttr
detail::constant_op_matcher m_Constant()
Matches a constant foldable operation.
Definition Matchers.h:369
function_ref< void(Block *, StringRef)> OpAsmSetBlockNameFn
A functor used to set the name of blocks in regions directly nested under an operation.
This is the representation of an operand reference.
This class provides APIs and verifiers for ops with regions having a single block.
This represents an operation in an abstracted form, suitable for use with the builder APIs.
void addOperands(ValueRange newOperands)
void addAttributes(ArrayRef< NamedAttribute > newAttributes)
Add an array of named attributes.
void addTypes(ArrayRef< Type > newTypes)
Region * addRegion()
Create a region that should be attached to the operation.